[FIRRTL] no back-prop for width of mux selectors, support narrower (#6917)

* Allow mux selectors to be zero-width or 1-bit (for mux4).

This is legal per FIRRTL spec.

* InferWidths: mux no back-prop.

Fixes https://github.com/llvm/circt/issues/5444

* canonicalizers for small mux selectors

Co-authored-by: Schuyler Eldridge <schuyler.eldridge@gmail.com>
This commit is contained in:
Will Dietz 2024-04-17 07:43:17 -05:00 committed by GitHub
parent 936806f0f5
commit ea77394a85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 116 additions and 22 deletions

View File

@ -55,6 +55,12 @@ def IntTypeWidthGEQ32 : Constraint<CPred<
def IntTypeWidthGT32 : Constraint<CPred<
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel() > type_cast<IntType>($1.getType()).getBitWidthOrSentinel()">>;
// sizeof(0) < X
class IntTypeWidthLTX<int X> : Constraint<CPred<
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel() >= 0 &&"
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel()< " # X
>>;
// Constraint that enforces int types
def IntTypes : Constraint<CPred<"type_isa<IntType>($0.getType())">>;
@ -562,6 +568,35 @@ def MuxNEQ : Pat<
(MoveNameHint $old, (MuxPrimOp (EQPrimOp $a, $b), $y, $x)),
[(EqualTypes $x, $y), (KnownWidth $x)]>;
// mux(cond : u0, a, b) -> mux(0 : u1, a, b)
def MuxPadSel : Pat<
(MuxPrimOp:$old $cond, $a, $b),
(MoveNameHint $old, (MuxPrimOp
(ConstantOp
(NativeCodeCall<"$_builder.getUI32IntegerAttr(0)">),
(returnType "$_builder.getType<UIntType>(1)")),
$a, $b)),
[(IntTypeWidthLTX<1> $cond)]>;
// mux2(cond : u0, a, b) -> mux2(0 : u1, a, b)
def Mux2PadSel : Pat<
(Mux2CellIntrinsicOp:$old $cond, $a, $b),
(MoveNameHint $old, (Mux2CellIntrinsicOp
(ConstantOp
(NativeCodeCall<"$_builder.getUI32IntegerAttr(0)">),
(returnType "$_builder.getType<UIntType>(1)")),
$a, $b)),
[(IntTypeWidthLTX<1> $cond)]>;
// mux4(cond : u0/u1, a, b) -> mux4(pad(cond -> u2), a, b)
def Mux4PadSel : Pat<
(Mux4CellIntrinsicOp:$old $cond, $a, $b, $c, $d),
(MoveNameHint $old, (Mux4CellIntrinsicOp
(PadPrimOp $cond,
(NativeCodeCall<"$_builder.getI32IntegerAttr(2)">)),
$a, $b, $c, $d)),
[(IntTypeWidthLTX<2> $cond)]>;
def CatDoubleConst : Pat <
(CatPrimOp:$old $cst1, (CatPrimOp $cst2, $v)),
(MoveNameHint $old, (CatPrimOp (CatPrimOp $cst1, (AsUIntPrimOp $cst2)), (AsUIntPrimOp $v))),

View File

@ -756,7 +756,7 @@ def HeadPrimOp : PrimOp<"head"> {
}
def MuxPrimOp : PrimOp<"mux"> {
let arguments = (ins UInt1OrUnsizedType:$sel, PassiveType:$high,
let arguments = (ins UIntLTE1OrUnsizedType:$sel, PassiveType:$high,
PassiveType:$low);
let results = (outs PassiveType:$result);
@ -842,10 +842,12 @@ def Mux2CellIntrinsicOp : PrimOp<"int.mux2cell"> {
the inference process in the same way as a normal mux operation.
}];
let arguments = (ins UInt1OrUnsizedType:$sel, PassiveType:$high,
let arguments = (ins UIntLTE1OrUnsizedType:$sel, PassiveType:$high,
PassiveType:$low);
let results = (outs PassiveType:$result);
let hasCanonicalizer = true;
let assemblyFormat =
"`(` operands `)` attr-dict `:` functional-type(operands, $result)";
}
@ -861,11 +863,13 @@ def Mux4CellIntrinsicOp : PrimOp<"int.mux4cell"> {
the inference process as a sugar of mux operation chains.
}];
let arguments = (ins UInt2OrUnsizedType:$sel, PassiveType:$v3,
let arguments = (ins UIntLTE2OrUnsizedType:$sel, PassiveType:$v3,
PassiveType:$v2, PassiveType:$v1,
PassiveType:$v0);
let results = (outs PassiveType:$result);
let hasCanonicalizer = true;
let assemblyFormat =
"`(` operands `)` attr-dict `:` functional-type(operands, $result)";
}

View File

@ -128,6 +128,12 @@ class SizedUIntType<int width> : FIRRTLDialectType<
"type_cast<UIntType>($_self).getWidth() == " # width>,
width # "-bit uint", "::circt::firrtl::UIntType">;
class SizedUIntTypeLTE<int width> : FIRRTLDialectType<
CPred<"type_isa<UIntType>($_self) && "
"type_cast<UIntType>($_self).getWidth() <= " # width>,
"uint with width less than or equal to " # width # " bits",
"::circt::firrtl::UIntType">;
class NonConstSizedUIntType<int width> :
SizedUIntType<width>,
BuildableType<
@ -138,8 +144,8 @@ def UInt2Type : SizedUIntType<2>;
def UInt32Type : SizedUIntType<32>;
def NonConstUInt1Type : NonConstSizedUIntType<1>;
def UInt1OrUnsizedType : AnyTypeOf<[UInt1Type, UnsizedUIntType]>;
def UInt2OrUnsizedType : AnyTypeOf<[UInt2Type, UnsizedUIntType]>;
def UIntLTE1OrUnsizedType : AnyTypeOf<[SizedUIntTypeLTE<1>, UnsizedUIntType]>;
def UIntLTE2OrUnsizedType : AnyTypeOf<[SizedUIntTypeLTE<2>, UnsizedUIntType]>;
//===----------------------------------------------------------------------===//
// FIRRTL Types Predicates

View File

@ -1476,10 +1476,22 @@ public:
void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse,
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS>(context);
results
.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
patterns::MuxSameTrue, patterns::MuxSameFalse,
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
context);
}
void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<patterns::Mux2PadSel>(context);
}
void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<patterns::Mux4PadSel>(context);
}
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {

View File

@ -1596,12 +1596,12 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
})
.Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
auto *sel = getExpr(op.getSel());
constrainTypes(sel, solver.known(1));
constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
})
.Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
auto *sel = getExpr(op.getSel());
constrainTypes(sel, solver.known(2));
constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
maximumOfTypes(op.getResult(), op.getResult(), op.getV0());

View File

@ -506,7 +506,9 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
out %out1: !firrtl.uint<1>,
out %out2: !firrtl.uint<0>,
out %out3: !firrtl.uint<1>,
out %out4: !firrtl.uint<4>) {
out %out4: !firrtl.uint<4>,
out %out5: !firrtl.uint<1>,
out %out6: !firrtl.uint<1>) {
// CHECK: firrtl.strictconnect %out, %in
%0 = firrtl.int.mux2cell (%cond, %in, %in) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.connect %out, %0 : !firrtl.uint<4>, !firrtl.uint<4>
@ -560,6 +562,15 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
// CHECK-NEXT: [[V2:%.+]] = firrtl.mux(%cond
// CHECK-NEXT: firrtl.strictconnect %out4, [[V2]]
firrtl.connect %out4, %15 : !firrtl.uint<4>, !firrtl.uint<4>
// CHECK-NEXT: firrtl.strictconnect %out5, %val2
%16 = firrtl.mux (%val0, %val1, %val2) : (!firrtl.uint<0>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out5, %16 : !firrtl.uint<1>
// CHECK-NEXT: %[[SEL:.+]] = firrtl.pad %val1, 2 : (!firrtl.uint<1>) -> !firrtl.uint<2>
// CHECK-NEXT: mux4cell(%[[SEL]],
%17 = firrtl.int.mux4cell (%val1, %val1, %val2, %val1, %val2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out6, %17 : !firrtl.uint<1>
}
// CHECK-LABEL: firrtl.module @Pad

View File

@ -186,3 +186,29 @@ firrtl.circuit "NoWidthEnum" {
firrtl.module @NoWidthEnum(out %o: !firrtl.enum<Some: uint>) {
}
}
// -----
firrtl.circuit "MuxSelBackProp" {
firrtl.module @MuxSelBackProp() {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
// expected-error @below {{uninferred width: wire is unconstrained}}
%0 = firrtl.wire : !firrtl.uint
%1 = firrtl.mux(%0, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
}
}
// -----
firrtl.circuit "MuxSelTooWide" {
firrtl.module @MuxSelTooWide() {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
// expected-error @below {{uninferred width: wire cannot satisfy all width requirements}}
%0 = firrtl.wire : !firrtl.uint
// expected-note @below {{width is constrained to be at most 1 here:}}
%1 = firrtl.mux(%0, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// expected-note @below {{width is constrained to be at least 2 here:}}
firrtl.connect %0, %c2_ui2 : !firrtl.uint, !firrtl.uint<2>
}
}

View File

@ -364,20 +364,18 @@ firrtl.circuit "Foo" {
firrtl.module @MuxOp() {
// CHECK: %0 = firrtl.wire : !firrtl.uint<2>
// CHECK: %1 = firrtl.wire : !firrtl.uint<3>
// CHECK: %2 = firrtl.wire : !firrtl.uint<1>
// CHECK: %2 = firrtl.wire : !firrtl.uint<0>
// CHECK: %3 = firrtl.mux{{.*}} -> !firrtl.uint<3>
%0 = firrtl.wire : !firrtl.uint
%1 = firrtl.wire : !firrtl.uint
%2 = firrtl.wire : !firrtl.uint
%3 = firrtl.mux(%2, %0, %1) : (!firrtl.uint, !firrtl.uint, !firrtl.uint) -> !firrtl.uint
// CHECK: %4 = firrtl.wire : !firrtl.uint<1>
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%4 = firrtl.wire : !firrtl.uint
%5 = firrtl.mux(%4, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
%c1_ui2 = firrtl.constant 1 : !firrtl.uint<2>
%c2_ui3 = firrtl.constant 2 : !firrtl.uint<3>
%c0_ui0 = firrtl.constant 0 : !firrtl.uint<0>
firrtl.connect %0, %c1_ui2 : !firrtl.uint, !firrtl.uint<2>
firrtl.connect %1, %c2_ui3 : !firrtl.uint, !firrtl.uint<3>
firrtl.connect %2, %c0_ui0 : !firrtl.uint, !firrtl.uint<0>
}
// see https://github.com/llvm/circt/issues/3070
@ -957,9 +955,7 @@ firrtl.circuit "Foo" {
firrtl.module @Property(in %a: !firrtl.string) { }
// CHECK-LABEL: module @MuxIntrinsics
// CHECK-SAME: %sel: !firrtl.uint<1>
// CHECK-SAME: %sel2: !firrtl.uint<2>
firrtl.module @MuxIntrinsics(in %sel: !firrtl.uint, in %sel2: !firrtl.uint, in %high: !firrtl.uint<1>, in %low: !firrtl.uint<1>, out %out1: !firrtl.uint, out %out2: !firrtl.uint) {
firrtl.module @MuxIntrinsics(in %sel_0w: !firrtl.uint<0>, in %sel_1w: !firrtl.uint<1>, in %high: !firrtl.uint<1>, in %low: !firrtl.uint<1>, out %out1: !firrtl.uint, out %out2: !firrtl.uint) {
%c3_ui4 = firrtl.constant 3 : !firrtl.uint<4>
%c3_ui3 = firrtl.constant 3 : !firrtl.uint<3>
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
@ -967,12 +963,16 @@ firrtl.circuit "Foo" {
%c1_ui2 = firrtl.constant 1 : !firrtl.uint<2>
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
%c1 = firrtl.constant 0: !firrtl.uint
%sel = firrtl.wire : !firrtl.uint
firrtl.connect %sel, %sel_0w : !firrtl.uint, !firrtl.uint<0>
// CHECK: firrtl.int.mux2cell
// CHECK-SAME: (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-SAME: (!firrtl.uint<0>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
%0 = firrtl.int.mux2cell(%sel, %c0_ui1, %c1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint) -> !firrtl.uint
firrtl.connect %out1, %0: !firrtl.uint, !firrtl.uint
%sel2 = firrtl.wire : !firrtl.uint
firrtl.connect %sel2, %sel_1w : !firrtl.uint, !firrtl.uint<1>
// CHECK: firrtl.int.mux4cell
// CHECK-SAME: (!firrtl.uint<2>, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint<1>) -> !firrtl.uint<3>
// CHECK-SAME: (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint<1>) -> !firrtl.uint<3>
%1 = firrtl.int.mux4cell(%sel2, %c1_ui1, %c2_ui2, %c3_ui3, %c1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint) -> !firrtl.uint
firrtl.connect %out2, %1: !firrtl.uint, !firrtl.uint
}