diff --git a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td index 7a5cd3c179..0b57030226 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td @@ -47,6 +47,9 @@ def EqualTypes : Constraint($0.getType()).getWidth() == type_cast($1.getType()).getWidth()">>; +// Constraint that enforces equal type signedness +def EqualSigns : Constraint($0.getType()).isSigned() == type_cast($1.getType()).isSigned()">>; + // sizeof(0) >= sizeof(1) def IntTypeWidthGEQ32 : Constraint($0.getType()).getBitWidthOrSentinel() >= type_cast($1.getType()).getBitWidthOrSentinel()">>; @@ -684,6 +687,12 @@ def CatDoubleConst : Pat < (MoveNameHint $old, (CatPrimOp (CatPrimOp $cst1, (AsUIntPrimOp $cst2)), (AsUIntPrimOp $v))), [(KnownWidth $v), (AnyConstantOp $cst1), (AnyConstantOp $cst2)]>; +// cat(asUint(x), asUint(y)) -> cat(x,y) +def CatCast : Pat < + (CatPrimOp:$old (AsUIntPrimOp $x), (AsUIntPrimOp $y)), + (MoveNameHint $old, (CatPrimOp $x, $y)), + [(EqualSigns $x, $y)]>; + // regreset(clock, constant_zero, resetValue) -> reg(clock) def RegResetWithZeroReset : Pat< (RegResetOp $clock, $reset, $_, $name, $nameKind, $annotations, $inner_sym, $forceable), diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index 0c23cc228c..aa93ebd0b5 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -1201,7 +1201,7 @@ void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, namespace { // cat(bits(x, ...), bits(x, ...)) -> bits(x ...) when the two ...'s are -// consequtive in the input. +// consecutive in the input. struct CatBitsBits : public mlir::RewritePattern { CatBitsBits(MLIRContext *context) : RewritePattern(CatPrimOp::getOperationName(), 0, context) {} @@ -1228,7 +1228,8 @@ struct CatBitsBits : public mlir::RewritePattern { void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); + results.insert( + context); } OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) { diff --git a/test/Dialect/FIRRTL/canonicalization.mlir b/test/Dialect/FIRRTL/canonicalization.mlir index e7b2fad560..735c4c8316 100644 --- a/test/Dialect/FIRRTL/canonicalization.mlir +++ b/test/Dialect/FIRRTL/canonicalization.mlir @@ -455,6 +455,7 @@ firrtl.module @Cat(in %in4: !firrtl.uint<4>, out %out4: !firrtl.uint<4>, out %outcst: !firrtl.uint<8>, out %outcst2: !firrtl.uint<8>, + out %outu8: !firrtl.uint<8>, in %in0 : !firrtl.uint<0>, out %outpt1: !firrtl.uint<4>, out %outpt2 : !firrtl.uint<4>) { @@ -490,6 +491,12 @@ firrtl.module @Cat(in %in4: !firrtl.uint<4>, %9 = firrtl.cat %c0_si2, %sin4 : (!firrtl.sint<2>, !firrtl.sint<4>) -> !firrtl.uint<6> %10 = firrtl.cat %c0_ui2, %9 : (!firrtl.uint<2>, !firrtl.uint<6>) -> !firrtl.uint<8> firrtl.connect %outcst, %10 : !firrtl.uint<8>, !firrtl.uint<8> + + // CHECK: %[[fixedsign:.*]] = firrtl.cat %sin4, %sin4 + // CHECK-NEXT: firrtl.strictconnect %outu8, %[[fixedsign]] + %tcast = firrtl.asUInt %sin4 : (!firrtl.sint<4>) -> !firrtl.uint<4> + %11 = firrtl.cat %tcast, %tcast : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<8> + firrtl.strictconnect %outu8, %11 : !firrtl.uint<8> } // CHECK-LABEL: firrtl.module @Bits