From adc7771f18c09d50742ee0e07584aa92404b2144 Mon Sep 17 00:00:00 2001 From: Duncan Sands Date: Tue, 23 Nov 2010 14:23:47 +0000 Subject: [PATCH] Exploit distributive laws (eg: And distributes over Or, Mul over Add, etc) in a fairly systematic way in instcombine. Some of these cases were already dealt with, in which case I removed the existing code. The case of Add has a bunch of funky logic which covers some of this plus a few variants (considers shifts to be a form of multiplication), which I didn't touch. The simplification performed is: A*B+A*C -> A*(B+C). The improvement is to do this in cases that were not already handled [such as A*B-A*C -> A*(B-C), which was reported on the mailing list], and also to do it more often by not checking for "only one use" if "B+C" simplifies. llvm-svn: 120024 --- llvm/lib/Transforms/InstCombine/InstCombine.h | 6 + .../InstCombine/InstCombineAddSub.cpp | 5 + .../InstCombine/InstCombineAndOrXor.cpp | 53 ++------- .../InstCombine/InstructionCombining.cpp | 111 ++++++++++++++++++ .../InstCombine/2010-11-23-Distributed.ll | 11 ++ 5 files changed, 144 insertions(+), 42 deletions(-) create mode 100644 llvm/test/Transforms/InstCombine/2010-11-23-Distributed.ll diff --git a/llvm/lib/Transforms/InstCombine/InstCombine.h b/llvm/lib/Transforms/InstCombine/InstCombine.h index 05846d0f9e17..b492777a4724 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombine.h +++ b/llvm/lib/Transforms/InstCombine/InstCombine.h @@ -290,6 +290,12 @@ private: /// operators which are associative or commutative. bool SimplifyAssociativeOrCommutative(BinaryOperator &I); + /// SimplifyDistributed - This tries to simplify binary operations which some + /// other binary operation distributes over (eg "A*B+A*C" -> "A*(B+C)" since + /// addition is distributed over by multiplication). Returns the result of + /// the simplification, or null if no simplification was performed. + Instruction *SimplifyDistributed(BinaryOperator &I); + /// SimplifyDemandedUseBits - Attempts to replace V with a simpler value /// based on the demanded bits. Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index c04a6b2a627e..b2919d8833b2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -91,6 +91,8 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { I.hasNoUnsignedWrap(), TD)) return ReplaceInstUsesWith(I, V); + if (Instruction *NV = SimplifyDistributed(I)) // (A*B)+(A*C) -> A*(B+C) + return NV; if (Constant *RHSC = dyn_cast(RHS)) { if (ConstantInt *CI = dyn_cast(RHSC)) { @@ -548,6 +550,9 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Op0 == Op1) // sub X, X -> 0 return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + if (Instruction *NV = SimplifyDistributed(I)) // (A*B)-(A*C) -> A*(B-C) + return NV; + // If this is a 'B = x-(-A)', change to B = x+A. This preserves NSW/NUW. if (Value *V = dyn_castNegVal(Op1)) { BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index a2569be16bd9..e9d72a4153e0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -984,6 +984,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyAndInst(Op0, Op1, TD)) return ReplaceInstUsesWith(I, V); + if (Instruction *NV = SimplifyDistributed(I)) // (A|B)&(A|C) -> A|(B&C) + return NV; + // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. if (SimplifyDemandedInstructionBits(I)) @@ -1692,6 +1695,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyOrInst(Op0, Op1, TD)) return ReplaceInstUsesWith(I, V); + if (Instruction *NV = SimplifyDistributed(I)) // (A&B)|(A&C) -> A&(B|C) + return NV; + // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. if (SimplifyDemandedInstructionBits(I)) @@ -1766,7 +1772,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Value *C = 0, *D = 0; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { - Value *V1 = 0, *V2 = 0, *V3 = 0; + Value *V1 = 0, *V2 = 0; C1 = dyn_cast(C); C2 = dyn_cast(D); if (C1 && C2) { // (A & C1)|(B & C2) @@ -1824,25 +1830,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } } - - // Check to see if we have any common things being and'ed. If so, find the - // terms for V1 & (V2|V3). - if (Op0->hasOneUse() || Op1->hasOneUse()) { - V1 = 0; - if (A == B) // (A & C)|(A & D) == A & (C|D) - V1 = A, V2 = C, V3 = D; - else if (A == D) // (A & C)|(B & A) == A & (B|C) - V1 = A, V2 = B, V3 = C; - else if (C == B) // (A & C)|(C & D) == C & (A|D) - V1 = C, V2 = A, V3 = D; - else if (C == D) // (A & C)|(B & C) == C & (A|B) - V1 = C, V2 = A, V3 = B; - - if (V1) { - Value *Or = Builder->CreateOr(V2, V3, "tmp"); - return BinaryOperator::CreateAnd(V1, Or); - } - } // (A & (C0?-1:0)) | (B & ~(C0?-1:0)) -> C0 ? A : B, and commuted variants. // Don't do this for vector select idioms, the code generator doesn't handle @@ -1979,6 +1966,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyXorInst(Op0, Op1, TD)) return ReplaceInstUsesWith(I, V); + if (Instruction *NV = SimplifyDistributed(I)) // (A&B)^(A&C) -> A&(B^C) + return NV; + // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. if (SimplifyDemandedInstructionBits(I)) @@ -2172,29 +2162,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if ((A == C && B == D) || (A == D && B == C)) return BinaryOperator::CreateXor(A, B); } - - // (A & B)^(C & D) - if ((Op0I->hasOneUse() || Op1I->hasOneUse()) && - match(Op0I, m_And(m_Value(A), m_Value(B))) && - match(Op1I, m_And(m_Value(C), m_Value(D)))) { - // (X & Y)^(X & Y) -> (Y^Z) & X - Value *X = 0, *Y = 0, *Z = 0; - if (A == C) - X = A, Y = B, Z = D; - else if (A == D) - X = A, Y = B, Z = C; - else if (B == C) - X = B, Y = A, Z = D; - else if (B == D) - X = B, Y = A, Z = C; - - if (X) { - Value *NewOp = Builder->CreateXor(Y, Z, Op0->getName()); - return BinaryOperator::CreateAnd(NewOp, X); - } - } } - + // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) if (ICmpInst *RHS = dyn_cast(I.getOperand(1))) if (ICmpInst *LHS = dyn_cast(I.getOperand(0))) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index cfc418256aab..ef7430cd72e0 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -237,6 +237,117 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { } while (1); } +/// LeftDistributesOverRight - Whether "X LOp (Y ROp Z)" is always equal to +/// "(X LOp Y) ROp (Z LOp Z)". +static bool LeftDistributesOverRight(Instruction::BinaryOps LOp, + Instruction::BinaryOps ROp) { + switch (LOp) { + default: + return false; + + case Instruction::And: + // And distributes over Or and Xor. + switch (ROp) { + default: + return false; + case Instruction::Or: + case Instruction::Xor: + return true; + } + + case Instruction::Mul: + // Multiplication distributes over addition and subtraction. + switch (ROp) { + default: + return false; + case Instruction::Add: + case Instruction::Sub: + return true; + } + + case Instruction::Or: + // Or distributes over And. + switch (ROp) { + default: + return false; + case Instruction::And: + return true; + } + } +} + +/// RightDistributesOverLeft - Whether "(X LOp Y) ROp Z" is always equal to +/// "(X ROp Z) LOp (Y ROp Z)". +static bool RightDistributesOverLeft(Instruction::BinaryOps LOp, + Instruction::BinaryOps ROp) { + if (Instruction::isCommutative(ROp)) + return LeftDistributesOverRight(ROp, LOp); + // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z", + // but this requires knowing that the addition does not overflow and other + // such subtleties. + return false; +} + +/// SimplifyDistributed - This tries to simplify binary operations which some +/// other binary operation distributes over (eg "A*B+A*C" -> "A*(B+C)" since +/// addition is distributed over by multiplication). Returns the result of +/// the simplification, or null if no simplification was performed. +Instruction *InstCombiner::SimplifyDistributed(BinaryOperator &I) { + BinaryOperator *Op0 = dyn_cast(I.getOperand(0)); + BinaryOperator *Op1 = dyn_cast(I.getOperand(1)); + if (!Op0 || !Op1 || Op0->getOpcode() != Op1->getOpcode()) + return 0; + + // The instruction has the form "(A op' B) op (C op' D)". + Value *A = Op0->getOperand(0); Value *B = Op0->getOperand(1); + Value *C = Op1->getOperand(0); Value *D = Op1->getOperand(1); + Instruction::BinaryOps OuterOpcode = I.getOpcode(); // op + Instruction::BinaryOps InnerOpcode = Op0->getOpcode(); // op' + + // Does "X op' (Y op Z)" always equal "(X op' Y) op (X op' Z)"? + bool LeftDistributes = LeftDistributesOverRight(InnerOpcode, OuterOpcode); + // Does "(X op Y) op' Z" always equal "(X op' Z) op (Y op' Z)"? + bool RightDistributes = RightDistributesOverLeft(OuterOpcode, InnerOpcode); + // Does "X op' Y" always equal "Y op' X"? + bool InnerCommutative = Instruction::isCommutative(InnerOpcode); + + if (LeftDistributes) + // Does the instruction have the form "(A op' B) op (A op' D)" or, in the + // commutative case, "(A op' B) op (C op' A)"? + if (A == C || (InnerCommutative && A == D)) { + if (A != C) + std::swap(C, D); + // Consider forming "A op' (B op D)". + // If "B op D" simplifies then it can be formed with no cost. + Value *RHS = SimplifyBinOp(OuterOpcode, B, D, TD); + // If "B op D" doesn't simplify then only proceed if both of the existing + // operations "A op' B" and "C op' D" will be zapped since no longer used. + if (!RHS && Op0->hasOneUse() && Op1->hasOneUse()) + RHS = Builder->CreateBinOp(OuterOpcode, B, D, Op1->getName()); + if (RHS) + return BinaryOperator::Create(InnerOpcode, A, RHS); + } + + if (RightDistributes) + // Does the instruction have the form "(A op' B) op (C op' B)" or, in the + // commutative case, "(A op' B) op (B op' D)"? + if (B == D || (InnerCommutative && B == C)) { + if (B != D) + std::swap(C, D); + // Consider forming "(A op C) op' B". + // If "A op C" simplifies then it can be formed with no cost. + Value *LHS = SimplifyBinOp(OuterOpcode, A, C, TD); + // If "A op C" doesn't simplify then only proceed if both of the existing + // operations "A op' B" and "C op' D" will be zapped since no longer used. + if (!LHS && Op0->hasOneUse() && Op1->hasOneUse()) + LHS = Builder->CreateBinOp(OuterOpcode, A, C, Op0->getName()); + if (LHS) + return BinaryOperator::Create(InnerOpcode, LHS, B); + } + + return 0; +} + // dyn_castNegVal - Given a 'sub' instruction, return the RHS of the instruction // if the LHS is a constant zero (which is the 'negate' form). // diff --git a/llvm/test/Transforms/InstCombine/2010-11-23-Distributed.ll b/llvm/test/Transforms/InstCombine/2010-11-23-Distributed.ll new file mode 100644 index 000000000000..13a5720dad23 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/2010-11-23-Distributed.ll @@ -0,0 +1,11 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s +define i32 @foo(i32 %x, i32 %y) { +; CHECK: @foo + %add = add nsw i32 %y, %x + %mul = mul nsw i32 %add, %y + %square = mul nsw i32 %y, %y + %res = sub i32 %mul, %square +; CHECK: %res = mul i32 %x, %y + ret i32 %res +; CHECK: ret i32 %res +}