From 5d8cf69cf222ec8776b1680d4d1de7e0e2aa4a86 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Wed, 21 Aug 2024 12:25:13 +0900 Subject: [PATCH] [CombFolds] Don't canonicalize extract(shl(1, x)) if shift is multiply used (#7527) There is a canonicalization for `exract(c, shl(1, x))` to `x == c` but this canonicalization introduces a bunch of comparision to constants. This harms PPA when bitwidth is large (e.g. 16 bit shift introduce 2^16 icmp op). To prevent such regressions this commit imposes restriction regarding the number of uses for shift. --- lib/Dialect/Comb/CombFolds.cpp | 25 ++++++++++++++----------- test/Dialect/Comb/canonicalization.mlir | 23 ++++++++++++++++------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index 96cb562b40..ececa63a30 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -693,17 +693,20 @@ LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) { // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is // extracted. if (cast(op.getType()).getWidth() == 1 && inputOp) - if (auto shlOp = dyn_cast(inputOp)) - if (auto lhsCst = shlOp.getOperand(0).getDefiningOp()) - if (lhsCst.getValue().isOne()) { - auto newCst = rewriter.create( - shlOp.getLoc(), - APInt(lhsCst.getValue().getBitWidth(), op.getLowBit())); - replaceOpWithNewOpAndCopyName(rewriter, op, ICmpPredicate::eq, - shlOp->getOperand(1), newCst, - false); - return success(); - } + if (auto shlOp = dyn_cast(inputOp)) { + // Don't canonicalize if the shift is multiply used. + if (shlOp->hasOneUse()) + if (auto lhsCst = shlOp.getLhs().getDefiningOp()) + if (lhsCst.getValue().isOne()) { + auto newCst = rewriter.create( + shlOp.getLoc(), + APInt(lhsCst.getValue().getBitWidth(), op.getLowBit())); + replaceOpWithNewOpAndCopyName( + rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst, + false); + return success(); + } + } return failure(); } diff --git a/test/Dialect/Comb/canonicalization.mlir b/test/Dialect/Comb/canonicalization.mlir index faa7e40334..77d0eadbc9 100644 --- a/test/Dialect/Comb/canonicalization.mlir +++ b/test/Dialect/Comb/canonicalization.mlir @@ -1221,17 +1221,26 @@ hw.module @test1560(in %value: i38, out a: i1) { } // CHECK-LABEL: hw.module @extractShift -hw.module @extractShift(in %arg0 : i4, out o1 : i1, out o2: i1) { +hw.module @extractShift(in %arg0 : i4, out o1 : i1, out o2: i1, out o3: i1, out o4: i1) { %c1 = hw.constant 1: i4 %0 = comb.shl %c1, %arg0 : i4 + %1 = comb.shl %c1, %arg0 : i4 + %2 = comb.shl %c1, %arg0 : i4 - // CHECK: %0 = comb.icmp eq %arg0, %c0_i4 : i4 - %1 = comb.extract %0 from 0 : (i4) -> i1 + // CHECK: %[[O1:.+]] = comb.icmp eq %arg0, %c0_i4 : i4 + %3 = comb.extract %0 from 0 : (i4) -> i1 - // CHECK: %1 = comb.icmp eq %arg0, %c2_i4 : i4 - %2 = comb.extract %0 from 2 : (i4) -> i1 - // CHECK: hw.output %0, %1 - hw.output %1, %2: i1, i1 + // CHECK: %[[O2:.+]] = comb.icmp eq %arg0, %c2_i4 : i4 + %4 = comb.extract %1 from 2 : (i4) -> i1 + + // CHECK: %[[O3:.+]] = comb.extract + %5 = comb.extract %2 from 2 : (i4) -> i1 + + // CHECK: %[[O4:.+]] = comb.extract + %6 = comb.extract %2 from 2 : (i4) -> i1 + + // CHECK: hw.output %[[O1]], %[[O2]], %[[O3]], %[[O4]] + hw.output %3, %4, %5, %6: i1, i1, i1, i1 } // CHECK-LABEL: hw.module @moduloZeroDividend