[CombFolds] Don't canonicalize extract(shl(1, x)) if shift is multiply used (#7527)

There is a canonicalization from `extract(c, shl(1, x))` to `x == c`, but this
canonicalization introduces a bunch of comparisons to constants. This harms
PPA when the bitwidth is large (e.g. a 16-bit shift introduces 2^16 icmp ops). To
prevent such regressions, this commit imposes a restriction on the number of uses
of the shift.
This commit is contained in:
Hideto Ueno 2024-08-21 12:25:13 +09:00 committed by GitHub
parent 29b1c1cc76
commit 5d8cf69cf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 18 deletions

View File

@ -693,17 +693,20 @@ LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
// `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
// extracted.
if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
if (auto shlOp = dyn_cast<ShlOp>(inputOp))
if (auto lhsCst = shlOp.getOperand(0).getDefiningOp<hw::ConstantOp>())
if (lhsCst.getValue().isOne()) {
auto newCst = rewriter.create<hw::ConstantOp>(
shlOp.getLoc(),
APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
shlOp->getOperand(1), newCst,
false);
return success();
}
if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
// Don't canonicalize if the shift is multiply used.
if (shlOp->hasOneUse())
if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
if (lhsCst.getValue().isOne()) {
auto newCst = rewriter.create<hw::ConstantOp>(
shlOp.getLoc(),
APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
replaceOpWithNewOpAndCopyName<ICmpOp>(
rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
false);
return success();
}
}
return failure();
}

View File

@ -1221,17 +1221,26 @@ hw.module @test1560(in %value: i38, out a: i1) {
}
// CHECK-LABEL: hw.module @extractShift
hw.module @extractShift(in %arg0 : i4, out o1 : i1, out o2: i1) {
hw.module @extractShift(in %arg0 : i4, out o1 : i1, out o2: i1, out o3: i1, out o4: i1) {
%c1 = hw.constant 1: i4
%0 = comb.shl %c1, %arg0 : i4
%1 = comb.shl %c1, %arg0 : i4
%2 = comb.shl %c1, %arg0 : i4
// CHECK: %0 = comb.icmp eq %arg0, %c0_i4 : i4
%1 = comb.extract %0 from 0 : (i4) -> i1
// CHECK: %[[O1:.+]] = comb.icmp eq %arg0, %c0_i4 : i4
%3 = comb.extract %0 from 0 : (i4) -> i1
// CHECK: %1 = comb.icmp eq %arg0, %c2_i4 : i4
%2 = comb.extract %0 from 2 : (i4) -> i1
// CHECK: hw.output %0, %1
hw.output %1, %2: i1, i1
// CHECK: %[[O2:.+]] = comb.icmp eq %arg0, %c2_i4 : i4
%4 = comb.extract %1 from 2 : (i4) -> i1
// CHECK: %[[O3:.+]] = comb.extract
%5 = comb.extract %2 from 2 : (i4) -> i1
// CHECK: %[[O4:.+]] = comb.extract
%6 = comb.extract %2 from 2 : (i4) -> i1
// CHECK: hw.output %[[O1]], %[[O2]], %[[O3]], %[[O4]]
hw.output %3, %4, %5, %6: i1, i1, i1, i1
}
// CHECK-LABEL: hw.module @moduloZeroDividend