diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index bec5af584f43..3dbb1ebebd7c 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1301,8 +1301,41 @@ static void updateSuccessor(BranchInst *BI, BasicBlock *OldBB, } // Move Lcssa PHIs to the right place. -static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerLatch, - BasicBlock *OuterLatch) { +static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerHeader, + BasicBlock *InnerLatch, BasicBlock *OuterHeader, + BasicBlock *OuterLatch, BasicBlock *OuterExit) { + + // Deal with LCSSA PHI nodes in the exit block of the inner loop, that are + // defined either in the header or latch. Those blocks will become header and + // latch of the new outer loop, and the only possible users can be PHI nodes + // in the exit block of the loop nest or the outer loop header (reduction + // PHIs, in that case, the incoming value must be defined in the inner loop + // header). We can just substitute the user with the incoming value and remove + // the PHI. + for (PHINode &P : make_early_inc_range(InnerExit->phis())) { + assert(P.getNumIncomingValues() == 1 && + "Only loops with a single exit are supported!"); + + // Incoming values are guaranteed to be instructions currently. + auto IncI = cast(P.getIncomingValueForBlock(InnerLatch)); + // Skip phis with incoming values from the inner loop body, excluding the + // header and latch. 
+ if (IncI->getParent() != InnerLatch && IncI->getParent() != InnerHeader) + continue; + + assert(all_of(P.users(), + [OuterHeader, OuterExit, IncI, InnerHeader](User *U) { + return (cast(U)->getParent() == OuterHeader && + IncI->getParent() == InnerHeader) || + cast(U)->getParent() == OuterExit; + }) && + "Can only replace phis iff the uses are in the loop nest exit or " + "the incoming value is defined in the inner header (it will " + "dominate all loop blocks after interchanging)"); + P.replaceAllUsesWith(IncI); + P.eraseFromParent(); + } + SmallVector LcssaInnerExit; for (PHINode &P : InnerExit->phis()) LcssaInnerExit.push_back(&P); @@ -1315,31 +1348,39 @@ static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerLatch, // If a PHI node has users outside of InnerExit, it has a use outside the // interchanged loop and we have to preserve it. We move these to // InnerLatch, which will become the new exit block for the innermost - // loop after interchanging. For PHIs only used in InnerExit, we can just - // replace them with the incoming value. - for (PHINode *P : LcssaInnerExit) { - bool hasUsersOutside = false; - for (auto UI = P->use_begin(), E = P->use_end(); UI != E;) { - Use &U = *UI; - ++UI; - auto *Usr = cast(U.getUser()); - if (Usr->getParent() != InnerExit) { - hasUsersOutside = true; - continue; - } - U.set(P->getIncomingValueForBlock(InnerLatch)); - } - if (hasUsersOutside) - P->moveBefore(InnerLatch->getFirstNonPHI()); - else - P->eraseFromParent(); - } + // loop after interchanging. + for (PHINode *P : LcssaInnerExit) + P->moveBefore(InnerLatch->getFirstNonPHI()); // If the inner loop latch contains LCSSA PHIs, those come from a child loop // and we have to move them to the new inner latch. for (PHINode *P : LcssaInnerLatch) P->moveBefore(InnerExit->getFirstNonPHI()); + // Deal with LCSSA PHI nodes in the loop nest exit block. 
For PHIs that have + // incoming values from the outer latch or header, we have to add a new PHI + // in the inner loop latch, which became the exit block of the outer loop, + // after interchanging. + if (OuterExit) { + for (PHINode &P : OuterExit->phis()) { + if (P.getNumIncomingValues() != 1) + continue; + // Skip Phis with incoming values not defined in the outer loop's header + // and latch. Also skip incoming phis defined in the latch. Those should + // already have been updated. + auto I = dyn_cast(P.getIncomingValue(0)); + if (!I || ((I->getParent() != OuterLatch || isa(I)) && + I->getParent() != OuterHeader)) + continue; + + PHINode *NewPhi = dyn_cast(P.clone()); + NewPhi->setIncomingValue(0, P.getIncomingValue(0)); + NewPhi->setIncomingBlock(0, OuterLatch); + NewPhi->insertBefore(InnerLatch->getFirstNonPHI()); + P.setIncomingValue(0, NewPhi); + } + } + // Now adjust the incoming blocks for the LCSSA PHIs. // For PHIs moved from Inner's exit block, we need to replace Inner's latch // with the new latch. @@ -1442,7 +1483,8 @@ bool LoopInterchangeTransform::adjustLoopBranches() { restructureLoops(OuterLoop, InnerLoop, InnerLoopPreHeader, OuterLoopPreHeader); - moveLCSSAPhis(InnerLoopLatchSuccessor, InnerLoopLatch, OuterLoopLatch); + moveLCSSAPhis(InnerLoopLatchSuccessor, InnerLoopHeader, InnerLoopLatch, + OuterLoopHeader, OuterLoopLatch, InnerLoop->getExitBlock()); // For PHIs in the exit block of the outer loop, outer's latch has been // replaced by Inners'. OuterLoopLatchSuccessor->replacePhiUsesWith(OuterLoopLatch, InnerLoopLatch); diff --git a/llvm/test/Transforms/LoopInterchange/perserve-lcssa.ll b/llvm/test/Transforms/LoopInterchange/perserve-lcssa.ll new file mode 100644 index 000000000000..af61709873c0 --- /dev/null +++ b/llvm/test/Transforms/LoopInterchange/perserve-lcssa.ll @@ -0,0 +1,181 @@ +; RUN: opt < %s -loop-interchange -loop-interchange-threshold=-100 -verify-loop-lcssa -S | FileCheck %s + +; Test case for PR41725. 
The induction variables in the latches escape the +; loops and we must move some PHIs around. + +@a = common dso_local global i64 0, align 4 +@b = common dso_local global i64 0, align 4 +@c = common dso_local global [10 x [1 x i32 ]] zeroinitializer, align 16 + + +define void @test_lcssa_indvars1() { +; CHECK-LABEL: @test_lcssa_indvars1() +; CHECK-LABEL: inner.body.split: +; CHECK-NEXT: %0 = phi i64 [ %iv.outer.next, %outer.latch ] +; CHECK-NEXT: %iv.inner.next = add nsw i64 %iv.inner, -1 + +; CHECK-LABEL: exit: +; CHECK-NEXT: %v4.lcssa = phi i64 [ %0, %inner.body.split ] +; CHECK-NEXT: %v8.lcssa.lcssa = phi i64 [ %iv.inner.next, %inner.body.split ] +; CHECK-NEXT: store i64 %v8.lcssa.lcssa, i64* @b, align 4 +; CHECK-NEXT: store i64 %v4.lcssa, i64* @a, align 4 + +entry: + br label %outer.header + +outer.header: ; preds = %outer.latch, %entry + %iv.outer = phi i64 [ 0, %entry ], [ %iv.outer.next, %outer.latch ] + br label %inner.body + +inner.body: ; preds = %inner.body, %outer.header + %iv.inner = phi i64 [ 5, %outer.header ], [ %iv.inner.next, %inner.body ] + %v7 = getelementptr inbounds [10 x [1 x i32]], [10 x [1 x i32]]* @c, i64 0, i64 %iv.inner, i64 %iv.outer + store i32 0, i32* %v7, align 4 + %iv.inner.next = add nsw i64 %iv.inner, -1 + %v9 = icmp eq i64 %iv.inner, 0 + br i1 %v9, label %outer.latch, label %inner.body + +outer.latch: ; preds = %inner.body + %v8.lcssa = phi i64 [ %iv.inner.next, %inner.body ] + %iv.outer.next = add nuw nsw i64 %iv.outer, 1 + %v5 = icmp ult i64 %iv.outer, 2 + br i1 %v5, label %outer.header, label %exit + +exit: ; preds = %outer.latch + %v4.lcssa = phi i64 [ %iv.outer.next, %outer.latch ] + %v8.lcssa.lcssa = phi i64 [ %v8.lcssa, %outer.latch ] + store i64 %v8.lcssa.lcssa, i64* @b, align 4 + store i64 %v4.lcssa, i64* @a, align 4 + ret void +} + + +define void @test_lcssa_indvars2() { +; CHECK-LABEL: @test_lcssa_indvars2() +; CHECK-LABEL: inner.body.split: +; CHECK-NEXT: %0 = phi i64 [ %iv.outer, %outer.latch ] +; CHECK-NEXT: 
%iv.inner.next = add nsw i64 %iv.inner, -1 + +; CHECK-LABEL: exit: +; CHECK-NEXT: %v4.lcssa = phi i64 [ %0, %inner.body.split ] +; CHECK-NEXT: %v8.lcssa.lcssa = phi i64 [ %iv.inner, %inner.body.split ] +; CHECK-NEXT: store i64 %v8.lcssa.lcssa, i64* @b, align 4 +; CHECK-NEXT: store i64 %v4.lcssa, i64* @a, align 4 + +entry: + br label %outer.header + +outer.header: ; preds = %outer.latch, %entry + %iv.outer = phi i64 [ 0, %entry ], [ %iv.outer.next, %outer.latch ] + br label %inner.body + +inner.body: ; preds = %inner.body, %outer.header + %iv.inner = phi i64 [ 5, %outer.header ], [ %iv.inner.next, %inner.body ] + %v7 = getelementptr inbounds [10 x [1 x i32]], [10 x [1 x i32]]* @c, i64 0, i64 %iv.inner, i64 %iv.outer + store i32 0, i32* %v7, align 4 + %iv.inner.next = add nsw i64 %iv.inner, -1 + %v9 = icmp eq i64 %iv.inner.next, 0 + br i1 %v9, label %outer.latch, label %inner.body + +outer.latch: ; preds = %inner.body + %v8.lcssa = phi i64 [ %iv.inner, %inner.body ] + %iv.outer.next = add nuw nsw i64 %iv.outer, 1 + %v5 = icmp ult i64 %iv.outer.next, 2 + br i1 %v5, label %outer.header, label %exit + +exit: ; preds = %outer.latch + %v4.lcssa = phi i64 [ %iv.outer, %outer.latch ] + %v8.lcssa.lcssa = phi i64 [ %v8.lcssa, %outer.latch ] + store i64 %v8.lcssa.lcssa, i64* @b, align 4 + store i64 %v4.lcssa, i64* @a, align 4 + ret void +} + +define void @test_lcssa_indvars3() { +; CHECK-LABEL: @test_lcssa_indvars3() +; CHECK-LABEL: inner.body.split: +; CHECK-NEXT: %0 = phi i64 [ %iv.outer.next, %outer.latch ] +; CHECK-NEXT: %iv.inner.next = add nsw i64 %iv.inner, -1 + +; CHECK-LABEL: exit: +; CHECK-NEXT: %v4.lcssa = phi i64 [ %0, %inner.body.split ] +; CHECK-NEXT: %v8.lcssa.lcssa = phi i64 [ %iv.inner.next, %inner.body.split ] +; CHECK-NEXT: %v8.lcssa.lcssa.2 = phi i64 [ %iv.inner.next, %inner.body.split ] +; CHECK-NEXT: %r1 = add i64 %v8.lcssa.lcssa, %v8.lcssa.lcssa.2 +; CHECK-NEXT: store i64 %r1, i64* @b, align 4 +; CHECK-NEXT: store i64 %v4.lcssa, i64* @a, align 4 + + 
+entry: + br label %outer.header + +outer.header: ; preds = %outer.latch, %entry + %iv.outer = phi i64 [ 0, %entry ], [ %iv.outer.next, %outer.latch ] + br label %inner.body + +inner.body: ; preds = %inner.body, %outer.header + %iv.inner = phi i64 [ 5, %outer.header ], [ %iv.inner.next, %inner.body ] + %v7 = getelementptr inbounds [10 x [1 x i32]], [10 x [1 x i32]]* @c, i64 0, i64 %iv.inner, i64 %iv.outer + store i32 0, i32* %v7, align 4 + %iv.inner.next = add nsw i64 %iv.inner, -1 + %v9 = icmp eq i64 %iv.inner, 0 + br i1 %v9, label %outer.latch, label %inner.body + +outer.latch: ; preds = %inner.body + %v8.lcssa = phi i64 [ %iv.inner.next, %inner.body ] + ;%const.lcssa = phi i64 [ 111, %inner.body ] + %iv.outer.next = add nuw nsw i64 %iv.outer, 1 + %v5 = icmp ult i64 %iv.outer, 2 + br i1 %v5, label %outer.header, label %exit + +exit: ; preds = %outer.latch + %v4.lcssa = phi i64 [ %iv.outer.next, %outer.latch ] + %v8.lcssa.lcssa = phi i64 [ %v8.lcssa, %outer.latch ] + %v8.lcssa.lcssa.2 = phi i64 [ %v8.lcssa, %outer.latch ] + %r1 = add i64 %v8.lcssa.lcssa, %v8.lcssa.lcssa.2 + store i64 %r1, i64* @b, align 4 + store i64 %v4.lcssa, i64* @a, align 4 + ret void +} + + +; Make sure we do not crash for loops without reachable exits. +define void @no_reachable_exits() { +; Check we interchanged. 
+; CHECK-LABEL: @no_reachable_exits() { +; CHECK-NEXT: bb: +; CHECK-NEXT: br label %inner.ph +; CHECK-LABEL: outer.ph: +; CHECK-NEXT: br label %outer.header +; CHECK-LABEL: inner.ph: +; CHECK-NEXT: br label %inner.body +; CHECK-LABEL: inner.body: +; CHECK-NEXT: %tmp31 = phi i32 [ 0, %inner.ph ], [ %tmp6, %inner.body.split ] +; CHECK-NEXT: br label %outer.ph + +bb: + br label %outer.ph + +outer.ph: ; preds = %bb + br label %outer.header + +outer.header: ; preds = %outer.ph, %outer.latch + %tmp2 = phi i32 [ 0, %outer.ph ], [ %tmp8, %outer.latch ] + br i1 undef, label %inner.ph, label %outer.latch + +inner.ph: ; preds = %outer.header + br label %inner.body + +inner.body: ; preds = %inner.ph, %inner.body + %tmp31 = phi i32 [ 0, %inner.ph ], [ %tmp6, %inner.body] + %tmp5 = load i32*, i32** undef, align 8 + %tmp6 = add nsw i32 %tmp31, 1 + br i1 undef, label %inner.body, label %outer.latch + +outer.latch: ; preds = %inner.body, %outer.header + %tmp8 = add nsw i32 %tmp2, 1 + br i1 undef, label %outer.header, label %exit + +exit: ; preds = %outer.latch + unreachable +}