mirror of https://github.com/llvm/circt.git
[FIRRTL] no back-prop for width of mux selectors, support narrower (#6917)
* Allow mux selectors to be zero-width or 1-bit (for mux4). This is legal per FIRRTL spec. * InferWidths: mux no back-prop. Fixes https://github.com/llvm/circt/issues/5444 * canonicalizers for small mux selectors Co-authored-by: Schuyler Eldridge <schuyler.eldridge@gmail.com>
This commit is contained in:
parent
936806f0f5
commit
ea77394a85
|
@ -55,6 +55,12 @@ def IntTypeWidthGEQ32 : Constraint<CPred<
|
|||
def IntTypeWidthGT32 : Constraint<CPred<
|
||||
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel() > type_cast<IntType>($1.getType()).getBitWidthOrSentinel()">>;
|
||||
|
||||
// sizeof(0) < X
|
||||
class IntTypeWidthLTX<int X> : Constraint<CPred<
|
||||
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel() >= 0 &&"
|
||||
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel()< " # X
|
||||
>>;
|
||||
|
||||
// Constraint that enforces int types
|
||||
def IntTypes : Constraint<CPred<"type_isa<IntType>($0.getType())">>;
|
||||
|
||||
|
@ -562,6 +568,35 @@ def MuxNEQ : Pat<
|
|||
(MoveNameHint $old, (MuxPrimOp (EQPrimOp $a, $b), $y, $x)),
|
||||
[(EqualTypes $x, $y), (KnownWidth $x)]>;
|
||||
|
||||
// mux(cond : u0, a, b) -> mux(0 : u1, a, b)
|
||||
def MuxPadSel : Pat<
|
||||
(MuxPrimOp:$old $cond, $a, $b),
|
||||
(MoveNameHint $old, (MuxPrimOp
|
||||
(ConstantOp
|
||||
(NativeCodeCall<"$_builder.getUI32IntegerAttr(0)">),
|
||||
(returnType "$_builder.getType<UIntType>(1)")),
|
||||
$a, $b)),
|
||||
[(IntTypeWidthLTX<1> $cond)]>;
|
||||
|
||||
// mux2(cond : u0, a, b) -> mux2(0 : u1, a, b)
|
||||
def Mux2PadSel : Pat<
|
||||
(Mux2CellIntrinsicOp:$old $cond, $a, $b),
|
||||
(MoveNameHint $old, (Mux2CellIntrinsicOp
|
||||
(ConstantOp
|
||||
(NativeCodeCall<"$_builder.getUI32IntegerAttr(0)">),
|
||||
(returnType "$_builder.getType<UIntType>(1)")),
|
||||
$a, $b)),
|
||||
[(IntTypeWidthLTX<1> $cond)]>;
|
||||
|
||||
// mux4(cond : u0/u1, a, b) -> mux4(pad(cond -> u2), a, b)
|
||||
def Mux4PadSel : Pat<
|
||||
(Mux4CellIntrinsicOp:$old $cond, $a, $b, $c, $d),
|
||||
(MoveNameHint $old, (Mux4CellIntrinsicOp
|
||||
(PadPrimOp $cond,
|
||||
(NativeCodeCall<"$_builder.getI32IntegerAttr(2)">)),
|
||||
$a, $b, $c, $d)),
|
||||
[(IntTypeWidthLTX<2> $cond)]>;
|
||||
|
||||
def CatDoubleConst : Pat <
|
||||
(CatPrimOp:$old $cst1, (CatPrimOp $cst2, $v)),
|
||||
(MoveNameHint $old, (CatPrimOp (CatPrimOp $cst1, (AsUIntPrimOp $cst2)), (AsUIntPrimOp $v))),
|
||||
|
|
|
@ -756,7 +756,7 @@ def HeadPrimOp : PrimOp<"head"> {
|
|||
}
|
||||
|
||||
def MuxPrimOp : PrimOp<"mux"> {
|
||||
let arguments = (ins UInt1OrUnsizedType:$sel, PassiveType:$high,
|
||||
let arguments = (ins UIntLTE1OrUnsizedType:$sel, PassiveType:$high,
|
||||
PassiveType:$low);
|
||||
let results = (outs PassiveType:$result);
|
||||
|
||||
|
@ -842,10 +842,12 @@ def Mux2CellIntrinsicOp : PrimOp<"int.mux2cell"> {
|
|||
the inference process in the same way as a normal mux operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins UInt1OrUnsizedType:$sel, PassiveType:$high,
|
||||
let arguments = (ins UIntLTE1OrUnsizedType:$sel, PassiveType:$high,
|
||||
PassiveType:$low);
|
||||
let results = (outs PassiveType:$result);
|
||||
|
||||
let hasCanonicalizer = true;
|
||||
|
||||
let assemblyFormat =
|
||||
"`(` operands `)` attr-dict `:` functional-type(operands, $result)";
|
||||
}
|
||||
|
@ -861,11 +863,13 @@ def Mux4CellIntrinsicOp : PrimOp<"int.mux4cell"> {
|
|||
the inference process as a sugar of mux operation chains.
|
||||
}];
|
||||
|
||||
let arguments = (ins UInt2OrUnsizedType:$sel, PassiveType:$v3,
|
||||
let arguments = (ins UIntLTE2OrUnsizedType:$sel, PassiveType:$v3,
|
||||
PassiveType:$v2, PassiveType:$v1,
|
||||
PassiveType:$v0);
|
||||
let results = (outs PassiveType:$result);
|
||||
|
||||
let hasCanonicalizer = true;
|
||||
|
||||
let assemblyFormat =
|
||||
"`(` operands `)` attr-dict `:` functional-type(operands, $result)";
|
||||
}
|
||||
|
|
|
@ -128,6 +128,12 @@ class SizedUIntType<int width> : FIRRTLDialectType<
|
|||
"type_cast<UIntType>($_self).getWidth() == " # width>,
|
||||
width # "-bit uint", "::circt::firrtl::UIntType">;
|
||||
|
||||
class SizedUIntTypeLTE<int width> : FIRRTLDialectType<
|
||||
CPred<"type_isa<UIntType>($_self) && "
|
||||
"type_cast<UIntType>($_self).getWidth() <= " # width>,
|
||||
"uint with width less than or equal to " # width # " bits",
|
||||
"::circt::firrtl::UIntType">;
|
||||
|
||||
class NonConstSizedUIntType<int width> :
|
||||
SizedUIntType<width>,
|
||||
BuildableType<
|
||||
|
@ -138,8 +144,8 @@ def UInt2Type : SizedUIntType<2>;
|
|||
def UInt32Type : SizedUIntType<32>;
|
||||
def NonConstUInt1Type : NonConstSizedUIntType<1>;
|
||||
|
||||
def UInt1OrUnsizedType : AnyTypeOf<[UInt1Type, UnsizedUIntType]>;
|
||||
def UInt2OrUnsizedType : AnyTypeOf<[UInt2Type, UnsizedUIntType]>;
|
||||
def UIntLTE1OrUnsizedType : AnyTypeOf<[SizedUIntTypeLTE<1>, UnsizedUIntType]>;
|
||||
def UIntLTE2OrUnsizedType : AnyTypeOf<[SizedUIntTypeLTE<2>, UnsizedUIntType]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FIRRTL Types Predicates
|
||||
|
|
|
@ -1476,10 +1476,22 @@ public:
|
|||
|
||||
void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
|
||||
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
|
||||
patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse,
|
||||
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS>(context);
|
||||
results
|
||||
.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
|
||||
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
|
||||
patterns::MuxSameTrue, patterns::MuxSameFalse,
|
||||
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
|
||||
context);
|
||||
}
|
||||
|
||||
void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results, MLIRContext *context) {
|
||||
results.add<patterns::Mux2PadSel>(context);
|
||||
}
|
||||
|
||||
void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results, MLIRContext *context) {
|
||||
results.add<patterns::Mux4PadSel>(context);
|
||||
}
|
||||
|
||||
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
|
||||
|
|
|
@ -1596,12 +1596,12 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
|
|||
})
|
||||
.Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
|
||||
auto *sel = getExpr(op.getSel());
|
||||
constrainTypes(sel, solver.known(1));
|
||||
constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
|
||||
maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
|
||||
})
|
||||
.Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
|
||||
auto *sel = getExpr(op.getSel());
|
||||
constrainTypes(sel, solver.known(2));
|
||||
constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
|
||||
maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
|
||||
maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
|
||||
maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
|
||||
|
|
|
@ -506,7 +506,9 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
|
|||
out %out1: !firrtl.uint<1>,
|
||||
out %out2: !firrtl.uint<0>,
|
||||
out %out3: !firrtl.uint<1>,
|
||||
out %out4: !firrtl.uint<4>) {
|
||||
out %out4: !firrtl.uint<4>,
|
||||
out %out5: !firrtl.uint<1>,
|
||||
out %out6: !firrtl.uint<1>) {
|
||||
// CHECK: firrtl.strictconnect %out, %in
|
||||
%0 = firrtl.int.mux2cell (%cond, %in, %in) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
|
||||
firrtl.connect %out, %0 : !firrtl.uint<4>, !firrtl.uint<4>
|
||||
|
@ -560,6 +562,15 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
|
|||
// CHECK-NEXT: [[V2:%.+]] = firrtl.mux(%cond
|
||||
// CHECK-NEXT: firrtl.strictconnect %out4, [[V2]]
|
||||
firrtl.connect %out4, %15 : !firrtl.uint<4>, !firrtl.uint<4>
|
||||
|
||||
// CHECK-NEXT: firrtl.strictconnect %out5, %val2
|
||||
%16 = firrtl.mux (%val0, %val1, %val2) : (!firrtl.uint<0>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
firrtl.strictconnect %out5, %16 : !firrtl.uint<1>
|
||||
|
||||
// CHECK-NEXT: %[[SEL:.+]] = firrtl.pad %val1, 2 : (!firrtl.uint<1>) -> !firrtl.uint<2>
|
||||
// CHECK-NEXT: mux4cell(%[[SEL]],
|
||||
%17 = firrtl.int.mux4cell (%val1, %val1, %val2, %val1, %val2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
firrtl.strictconnect %out6, %17 : !firrtl.uint<1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: firrtl.module @Pad
|
||||
|
|
|
@ -186,3 +186,29 @@ firrtl.circuit "NoWidthEnum" {
|
|||
firrtl.module @NoWidthEnum(out %o: !firrtl.enum<Some: uint>) {
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
firrtl.circuit "MuxSelBackProp" {
|
||||
firrtl.module @MuxSelBackProp() {
|
||||
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
|
||||
// expected-error @below {{uninferred width: wire is unconstrained}}
|
||||
%0 = firrtl.wire : !firrtl.uint
|
||||
%1 = firrtl.mux(%0, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
firrtl.circuit "MuxSelTooWide" {
|
||||
firrtl.module @MuxSelTooWide() {
|
||||
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
|
||||
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
|
||||
// expected-error @below {{uninferred width: wire cannot satisfy all width requirements}}
|
||||
%0 = firrtl.wire : !firrtl.uint
|
||||
// expected-note @below {{width is constrained to be at most 1 here:}}
|
||||
%1 = firrtl.mux(%0, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
// expected-note @below {{width is constrained to be at least 2 here:}}
|
||||
firrtl.connect %0, %c2_ui2 : !firrtl.uint, !firrtl.uint<2>
|
||||
}
|
||||
}
|
||||
|
|
|
@ -364,20 +364,18 @@ firrtl.circuit "Foo" {
|
|||
firrtl.module @MuxOp() {
|
||||
// CHECK: %0 = firrtl.wire : !firrtl.uint<2>
|
||||
// CHECK: %1 = firrtl.wire : !firrtl.uint<3>
|
||||
// CHECK: %2 = firrtl.wire : !firrtl.uint<1>
|
||||
// CHECK: %2 = firrtl.wire : !firrtl.uint<0>
|
||||
// CHECK: %3 = firrtl.mux{{.*}} -> !firrtl.uint<3>
|
||||
%0 = firrtl.wire : !firrtl.uint
|
||||
%1 = firrtl.wire : !firrtl.uint
|
||||
%2 = firrtl.wire : !firrtl.uint
|
||||
%3 = firrtl.mux(%2, %0, %1) : (!firrtl.uint, !firrtl.uint, !firrtl.uint) -> !firrtl.uint
|
||||
// CHECK: %4 = firrtl.wire : !firrtl.uint<1>
|
||||
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
|
||||
%4 = firrtl.wire : !firrtl.uint
|
||||
%5 = firrtl.mux(%4, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
%c1_ui2 = firrtl.constant 1 : !firrtl.uint<2>
|
||||
%c2_ui3 = firrtl.constant 2 : !firrtl.uint<3>
|
||||
%c0_ui0 = firrtl.constant 0 : !firrtl.uint<0>
|
||||
firrtl.connect %0, %c1_ui2 : !firrtl.uint, !firrtl.uint<2>
|
||||
firrtl.connect %1, %c2_ui3 : !firrtl.uint, !firrtl.uint<3>
|
||||
firrtl.connect %2, %c0_ui0 : !firrtl.uint, !firrtl.uint<0>
|
||||
}
|
||||
|
||||
// see https://github.com/llvm/circt/issues/3070
|
||||
|
@ -957,9 +955,7 @@ firrtl.circuit "Foo" {
|
|||
firrtl.module @Property(in %a: !firrtl.string) { }
|
||||
|
||||
// CHECK-LABEL: module @MuxIntrinsics
|
||||
// CHECK-SAME: %sel: !firrtl.uint<1>
|
||||
// CHECK-SAME: %sel2: !firrtl.uint<2>
|
||||
firrtl.module @MuxIntrinsics(in %sel: !firrtl.uint, in %sel2: !firrtl.uint, in %high: !firrtl.uint<1>, in %low: !firrtl.uint<1>, out %out1: !firrtl.uint, out %out2: !firrtl.uint) {
|
||||
firrtl.module @MuxIntrinsics(in %sel_0w: !firrtl.uint<0>, in %sel_1w: !firrtl.uint<1>, in %high: !firrtl.uint<1>, in %low: !firrtl.uint<1>, out %out1: !firrtl.uint, out %out2: !firrtl.uint) {
|
||||
%c3_ui4 = firrtl.constant 3 : !firrtl.uint<4>
|
||||
%c3_ui3 = firrtl.constant 3 : !firrtl.uint<3>
|
||||
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
|
||||
|
@ -967,12 +963,16 @@ firrtl.circuit "Foo" {
|
|||
%c1_ui2 = firrtl.constant 1 : !firrtl.uint<2>
|
||||
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
|
||||
%c1 = firrtl.constant 0: !firrtl.uint
|
||||
%sel = firrtl.wire : !firrtl.uint
|
||||
firrtl.connect %sel, %sel_0w : !firrtl.uint, !firrtl.uint<0>
|
||||
// CHECK: firrtl.int.mux2cell
|
||||
// CHECK-SAME: (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
// CHECK-SAME: (!firrtl.uint<0>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
%0 = firrtl.int.mux2cell(%sel, %c0_ui1, %c1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint) -> !firrtl.uint
|
||||
firrtl.connect %out1, %0: !firrtl.uint, !firrtl.uint
|
||||
%sel2 = firrtl.wire : !firrtl.uint
|
||||
firrtl.connect %sel2, %sel_1w : !firrtl.uint, !firrtl.uint<1>
|
||||
// CHECK: firrtl.int.mux4cell
|
||||
// CHECK-SAME: (!firrtl.uint<2>, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint<1>) -> !firrtl.uint<3>
|
||||
// CHECK-SAME: (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint<1>) -> !firrtl.uint<3>
|
||||
%1 = firrtl.int.mux4cell(%sel2, %c1_ui1, %c2_ui2, %c3_ui3, %c1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint) -> !firrtl.uint
|
||||
firrtl.connect %out2, %1: !firrtl.uint, !firrtl.uint
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue