[FIRRTL] Fold Mux with the same constant high and low (#1566)

Add a fold pattern for mux when the low and high are the same constant.
That is, 
x = mux (cond, <constant1> ,<constant1>) 
can be replaced with, 
x = constant1
This can enable IMConstProp to constant propagate through mux trees.
This commit is contained in:
Prithayan Barua 2021-08-14 02:50:12 -07:00 committed by GitHub
parent c36dd88568
commit 0375e732c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 0 deletions

View File

@ -1045,6 +1045,9 @@ OpFoldResult MuxPrimOp::fold(ArrayRef<Attribute> operands) {
if (auto lowCst = operands[2].dyn_cast_or_null<IntegerAttr>()) {
// mux(cond, c1, c2)
if (auto highCst = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (highCst.getType() == lowCst.getType() &&
highCst.getValue() == lowCst.getValue())
return highCst;
// mux(cond, 1, 0) -> cond
if (highCst.getValue().isOneValue() && lowCst.getValue().isNullValue() &&
getType() == sel().getType())

View File

@ -432,3 +432,59 @@ firrtl.circuit "rhs_sink_output_used_as_wire" {
firrtl.connect %d, %bar_d : !firrtl.uint<1>, !firrtl.uint<1>
}
}
firrtl.circuit "constRegReset" {
// CHECK-LABEL: firrtl.module @constRegReset
firrtl.module @constRegReset(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %cond: !firrtl.uint<1>, out %z: !firrtl.uint<8>) {
%c11_ui8 = firrtl.constant 11 : !firrtl.uint<8>
%r = firrtl.regreset %clock, %reset, %c11_ui8 : !firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>
%0 = firrtl.mux(%cond, %c11_ui8, %r) : (!firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>) -> !firrtl.uint<8>
firrtl.connect %r, %0 : !firrtl.uint<8>, !firrtl.uint<8>
// CHECK: %[[C13:.+]] = firrtl.constant 11
// CHECK: firrtl.connect %z, %[[C13]]
firrtl.connect %z, %r : !firrtl.uint<8>, !firrtl.uint<8>
}
}
firrtl.circuit "constRegReset2" {
// CHECK-LABEL: firrtl.module @constRegReset2
firrtl.module @constRegReset2(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %cond: !firrtl.uint<1>, out %z: !firrtl.uint<8>) {
%c11_ui8 = firrtl.constant 11 : !firrtl.uint<8>
%c11_ui4 = firrtl.constant 11 : !firrtl.uint<4>
%r = firrtl.regreset %clock, %reset, %c11_ui4 : !firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<8>
%0 = firrtl.mux(%cond, %c11_ui8, %r) : (!firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>) -> !firrtl.uint<8>
firrtl.connect %r, %0 : !firrtl.uint<8>, !firrtl.uint<8>
// CHECK: %[[C14:.+]] = firrtl.constant 11
// CHECK: firrtl.connect %z, %[[C14]]
firrtl.connect %z, %r : !firrtl.uint<8>, !firrtl.uint<8>
}
}
firrtl.circuit "regMuxTree" {
// CHECK-LABEL: firrtl.module @regMuxTree
firrtl.module @regMuxTree(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %cmd: !firrtl.uint<3>, out %z: !firrtl.uint<8>) {
%c7_ui8 = firrtl.constant 7 : !firrtl.uint<8>
%c2_ui8 = firrtl.constant 2 : !firrtl.uint<8>
%c2_ui3 = firrtl.constant 2 : !firrtl.uint<3>
%c1_ui3 = firrtl.constant 1 : !firrtl.uint<3>
%c7_ui4 = firrtl.constant 7 : !firrtl.uint<4>
%r = firrtl.regreset %clock, %reset, %c7_ui4 : !firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<8>
%0 = firrtl.orr %cmd : (!firrtl.uint<3>) -> !firrtl.uint<1>
%1 = firrtl.not %0 : (!firrtl.uint<1>) -> !firrtl.uint<1>
%2 = firrtl.not %1 : (!firrtl.uint<1>) -> !firrtl.uint<1>
%3 = firrtl.eq %cmd, %c1_ui3 : (!firrtl.uint<3>, !firrtl.uint<3>) -> !firrtl.uint<1>
%4 = firrtl.and %2, %3 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
%5 = firrtl.not %3 : (!firrtl.uint<1>) -> !firrtl.uint<1>
%6 = firrtl.and %2, %5 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
%7 = firrtl.eq %cmd, %c2_ui3 : (!firrtl.uint<3>, !firrtl.uint<3>) -> !firrtl.uint<1>
%8 = firrtl.and %6, %7 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
%9 = firrtl.mux(%8, %c7_ui8, %r) : (!firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>) -> !firrtl.uint<8>
%10 = firrtl.mux(%4, %r, %9) : (!firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>) -> !firrtl.uint<8>
%11 = firrtl.mux(%1, %c7_ui8, %10) : (!firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>) -> !firrtl.uint<8>
firrtl.connect %r, %11 : !firrtl.uint<8>, !firrtl.uint<8>
firrtl.connect %z, %r : !firrtl.uint<8>, !firrtl.uint<8>
// CHECK: %[[c7_ui8:.+]] = firrtl.constant 7 : !firrtl.uint<8>
// CHECK: firrtl.connect %z, %[[c7_ui8]] : !firrtl.uint<8>, !firrtl.uint<8>
}
}