diff --git a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td index d72e5a4a1c..3e3d7b1585 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td @@ -64,16 +64,9 @@ def ZeroConstantOp : Constraint().value() == false"> ]>>; -def IsInvalid : Constraint< - CPred<"$0.getDefiningOp()"> ->; - def GetEmptyString : NativeCodeCall< "StringAttr::get($_builder.getContext(), {}) ">; -def GetZeroConstant : NativeCodeCall< - "$_builder.create($0.getLoc(), APSInt($0.getType().cast().getBitWidthOrSentinel()))">; - // leq(const, x) -> geq(x, const) def LEQWithConstLHS : Pat< (LEQPrimOp $lhs, $rhs), @@ -142,33 +135,6 @@ def GetWidthAsIntAttr : NativeCodeCall< "IntegerAttr::get(IntegerType::get($_builder.getContext(), 32, IntegerType::Signless), " "$0.getType().cast().getBitWidthOrSentinel())">; -// add(x, invalid) -> pad(x, width) -// -// This is legal because it aligns with the Scala FIRRTL Compiler -// interpretation of lowering invalid to constant zero before constant -// propagation. -def AddWithInvalidOp : Pat< - (AddPrimOp:$result $x, $y), - (PadPrimOp $x, (GetWidthAsIntAttr $result)), [ - (KnownWidth $x), (IsInvalid $y) - ]>; - -// sub(x, invalid) -> pad(x, width) -// -// This is legal because it aligns with the Scala FIRRTL Compiler -// interpretation of lowering invalid to constant zero before constant -// propagation. -def SubWithInvalidOp : Pat< - (SubPrimOp:$result $x, $y), - (PadPrimOp $x, (GetWidthAsIntAttr $result)), [ - (KnownWidth $x), (IsInvalid $y) - ]>; - -// pad(invalid, width) -> zero -def PadInvalid : Pat< - (PadPrimOp:$result (InvalidValueOp), $_), - (GetZeroConstant $result), []>; - //////////////////////////////////////////////////////////////////////////////// // DontTouch application //////////////////////////////////////////////////////////////////////////////// diff --git a/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td b/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td index 9ba9b891b2..2dd84f861e 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td @@ -296,7 +296,7 @@ class IntBinaryPrimOp; let inferType = "impl::inferAddSubResult" in { - let hasCanonicalizer = true in { + let hasCanonicalizeMethod = true in { def AddPrimOp : IntBinaryPrimOp<"add", IntType, [Commutative]>; def SubPrimOp : IntBinaryPrimOp<"sub", IntType>; } @@ -480,7 +480,6 @@ def PadPrimOp : PrimOp<"pad"> { }]; let parseValidator = "impl::validateOneOperandOneConst"; - let hasCanonicalizer = true; } class ShiftPrimOp : PrimOp { diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index dd15c07992..09680e7468 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" // Forward Decl for patterns. static bool isUselessName(circt::StringRef name); @@ -93,11 +94,12 @@ static bool isUselessName(StringRef name) { } /// Implicitly replace the operand to a constant folding operation with a const -/// 0 in case the operand is non-constant but has a bit width 0. +/// 0 in case the operand is non-constant but has a bit width 0, or if the +/// operand is an invalid value. /// /// This makes constant folding significantly easier, as we can simply pass the /// operands to an operation through this function to appropriately replace any -/// zero-width dynamic values with a constant of value 0. +/// zero-width dynamic values or invalid values with a constant of value 0. static Optional getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) { assert(operand.getType().isa() && @@ -107,6 +109,10 @@ static Optional getExtendedConstant(Value operand, Attribute constant, if (destWidth < 0) return {}; + // InvalidValue inputs simply read as zero. + if (auto result = constant.dyn_cast_or_null()) + return APSInt(destWidth, operand.getType().cast().isUnsigned()); + // Extension signedness follows the operand sign. if (IntegerAttr result = constant.dyn_cast_or_null()) return extOrTruncZeroWidth(result.getAPSInt(), destWidth); @@ -118,6 +124,34 @@ static Optional getExtendedConstant(Value operand, Attribute constant, return {}; } +/// Determine the value of a constant operand for the sake of constant folding. +/// This will map `invalidvalue` to a zero value of the corresopnding type, +/// which aligns with how the Scala FIRRTL compiler handles invalids in most +/// cases. For a full discussion of this see the FIRRTL Rationale document. +static Optional getConstant(Attribute operand) { + if (!operand) + return {}; + if (auto attr = operand.dyn_cast()) { + if (auto type = attr.getType().dyn_cast()) + return APSInt(type.getWidth().getValueOr(1), type.isUnsigned()); + if (attr.getType().isa()) + return APSInt(1); + } + if (auto attr = operand.dyn_cast()) + return attr.getAPSInt(); + if (auto attr = operand.dyn_cast()) + return APSInt(APInt(1, attr.getValue())); + return {}; +} + +/// Determine whether a constant operand is a zero value for the sake of +/// constant folding. This considers `invalidvalue` to be zero. +static bool isConstantZero(Attribute operand) { + if (auto cst = getConstant(operand)) + return cst->isZero(); + return false; +} + /// This is the policy for folding, which depends on the sort of operator we're /// processing. enum class BinOpKind { @@ -192,6 +226,63 @@ constFoldFIRRTLBinaryOp(Operation *op, ArrayRef operands, return getIntAttr(resultType, resultValue); } +/// Applies the canonicalization function `canonicalize` to the given operation. +/// +/// Determines which (if any) of the operation's operands are constants, and +/// provides them as arguments to the callback function. Any `invalidvalue` in +/// the input is mapped to a constant zero. The value returned from the callback +/// is used as the replacement for `op`, and an additional pad operation is +/// inserted if necessary. Does nothing if the result of `op` is of unknown +/// width, in which case the necessity of a pad cannot be determined. +static LogicalResult canonicalizePrimOp( + Operation *op, PatternRewriter &rewriter, + const function_ref)> &canonicalize) { + // Can only operate on FIRRTL primitive operations. + if (op->getNumResults() != 1) + return failure(); + auto type = op->getResult(0).getType().dyn_cast(); + if (!type) + return failure(); + + // Can only operate on operations with a known result width. + auto width = type.getBitWidthOrSentinel(); + if (width < 0) + return failure(); + + // Determine which of the operands are constants. + SmallVector constOperands; + constOperands.reserve(op->getNumOperands()); + for (auto operand : op->getOperands()) { + Attribute attr; + if (auto *defOp = operand.getDefiningOp()) + TypeSwitch(defOp) + .Case( + [&](auto op) { attr = op.fold({}).template get(); }); + constOperands.push_back(attr); + } + + // Perform the canonicalization and materialize the result if it is a + // constant. + auto result = canonicalize(constOperands); + if (!result) + return failure(); + Value resultValue; + if (auto cst = result.dyn_cast()) + resultValue = op->getDialect() + ->materializeConstant(rewriter, cst, type, op->getLoc()) + ->getResult(0); + else + resultValue = result.get(); + + // Insert a pad if the type widths disagree. + if (width != resultValue.getType().cast().getBitWidthOrSentinel()) + resultValue = rewriter.create(op->getLoc(), resultValue, width); + + assert(type == resultValue.getType() && "canonicalization changed type"); + rewriter.replaceOp(op, resultValue); + return success(); +} + /// Get the largest unsigned value of a given bit width. Returns a 1-bit zero /// value if `bitWidth` is 0. static APInt getMaxUnsignedValue(unsigned bitWidth) { @@ -237,9 +328,17 @@ OpFoldResult AddPrimOp::fold(ArrayRef operands) { [=](APSInt a, APSInt b) { return a + b; }); } -void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert(context); +LogicalResult AddPrimOp::canonicalize(AddPrimOp op, PatternRewriter &rewriter) { + return canonicalizePrimOp(op, rewriter, + [&](ArrayRef operands) -> OpFoldResult { + // add(x, 0) -> x + if (isConstantZero(operands[1])) + return op.getOperand(0); + // add(0, x) -> x + if (isConstantZero(operands[0])) + return op.getOperand(1); + return {}; + }); } OpFoldResult SubPrimOp::fold(ArrayRef operands) { @@ -247,20 +346,33 @@ OpFoldResult SubPrimOp::fold(ArrayRef operands) { [=](APSInt a, APSInt b) { return a - b; }); } -void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert(context); +LogicalResult SubPrimOp::canonicalize(SubPrimOp op, PatternRewriter &rewriter) { + return canonicalizePrimOp( + op, rewriter, [&](ArrayRef operands) -> OpFoldResult { + // sub(x, 0) -> x + if (isConstantZero(operands[1])) + return op.getOperand(0); + // sub(0, x) -> neg(x) if x is signed + // sub(0, x) -> asUInt(neg(x)) if x is unsigned + if (isConstantZero(operands[0])) { + Value value = + rewriter.create(op.getLoc(), op.getOperand(1)); + if (op.getType().isa()) + value = rewriter.create(op.getLoc(), value); + return value; + } + return {}; + }); } OpFoldResult MulPrimOp::fold(ArrayRef operands) { - // mul(x, invalid) -> 0 + // mul(x, 0) -> 0 // // This is legal because it aligns with the Scala FIRRTL Compiler // interpretation of lowering invalid to constant zero before constant // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize // multiplication this way and will emit "x * 0". - if (operands[1].dyn_cast_or_null() || - operands[0].dyn_cast_or_null()) + if (isConstantZero(operands[1]) || isConstantZero(operands[0])) return getIntZerosAttr(getType()); return constFoldFIRRTLBinaryOp(*this, operands, BinOpKind::Normal, @@ -283,14 +395,13 @@ OpFoldResult DivPrimOp::fold(ArrayRef operands) { return getIntAttr(getType(), APInt(width, 1)); } - // div(invalid, x) -> 0 + // div(0, x) -> 0 // // This is legal because it aligns with the Scala FIRRTL Compiler // interpretation of lowering invalid to constant zero before constant // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize // division this way and will emit "0 / x". - if (operands[0].dyn_cast_or_null() && - !operands[1].dyn_cast_or_null()) + if (isConstantZero(operands[0]) && !isConstantZero(operands[1])) return getIntZerosAttr(getType()); /// div(x, 1) -> x : (uint, uint) -> uint @@ -311,14 +422,22 @@ OpFoldResult DivPrimOp::fold(ArrayRef operands) { } OpFoldResult RemPrimOp::fold(ArrayRef operands) { - // rem(invalid, x) -> 0 + // rem(x, x) -> 0 + // + // Division by zero is undefined in the FIRRTL specification. This fold + // exploits that fact to optimize self division remainder to zero. Note: + // this should supersede any division with invalid or zero. Remainder of + // division of invalid by invalid should be zero. + if (lhs() == rhs()) + return getIntZerosAttr(getType()); + + // rem(0, x) -> 0 // // This is legal because it aligns with the Scala FIRRTL Compiler // interpretation of lowering invalid to constant zero before constant // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize // division this way and will emit "0 % x". - if (operands[0].dyn_cast_or_null() && - !operands[1].dyn_cast_or_null()) + if (isConstantZero(operands[0])) return getIntZerosAttr(getType()); return constFoldFIRRTLBinaryOp(*this, operands, BinOpKind::DivideOrShift, @@ -353,22 +472,13 @@ OpFoldResult DShrPrimOp::fold(ArrayRef operands) { // TODO: Move to DRR. OpFoldResult AndPrimOp::fold(ArrayRef operands) { - // and(x, invalid) -> 0 - // - // This is legal because it aligns with the Scala FIRRTL Compiler - // interpretation of lowering invalid to constant zero before constant - // propagation. - if (operands[1].dyn_cast_or_null() || - operands[0].dyn_cast_or_null()) - return getIntZerosAttr(getType()); - - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { /// and(x, 0) -> 0 - if (rhsCst.getValue().isZero() && rhs().getType() == getType()) - return rhs(); + if (rhsCst->isZero() && rhs().getType() == getType()) + return getIntZerosAttr(getType()); /// and(x, -1) -> x - if (rhsCst.getValue().isAllOnes() && lhs().getType() == getType() && + if (rhsCst->isAllOnes() && lhs().getType() == getType() && rhs().getType() == getType()) return lhs(); } @@ -383,20 +493,7 @@ OpFoldResult AndPrimOp::fold(ArrayRef operands) { } OpFoldResult OrPrimOp::fold(ArrayRef operands) { - // or(x, invalid) -> x - // or(invalid, x) -> x - // - // This is legal because it aligns with the Scala FIRRTL Compiler - // interpretation of lowering invalid to constant zero before constant - // propagation. - if (operands[0].dyn_cast_or_null() && - rhs().getType() == getType()) - return rhs(); - if (operands[1].dyn_cast_or_null() && - lhs().getType() == getType()) - return lhs(); - - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { /// or(x, 0) -> x if (rhsCst.getValue().isZero() && lhs().getType() == getType()) return lhs(); @@ -417,21 +514,8 @@ OpFoldResult OrPrimOp::fold(ArrayRef operands) { } OpFoldResult XorPrimOp::fold(ArrayRef operands) { - // xor(x, invalid) -> x - // xor(invalid, x) -> x - // - // This is legal because it aligns with the Scala FIRRTL Compiler - // interpretation of lowering invalid to constant zero before constant - // propagation. - if (operands[0].dyn_cast_or_null() && - rhs().getType() == getType()) - return rhs(); - if (operands[1].dyn_cast_or_null() && - lhs().getType() == getType()) - return lhs(); - /// xor(x, 0) -> x - if (auto rhsCst = operands[1].dyn_cast_or_null()) + if (auto rhsCst = getConstant(operands[1])) if (rhsCst.getValue().isZero() && lhs().getType() == getType()) return lhs(); @@ -461,7 +545,7 @@ OpFoldResult LEQPrimOp::fold(ArrayRef operands) { // Comparison against constant outside type bounds. if (auto width = lhs().getType().cast().getWidth()) { - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { auto commonWidth = std::max(*width, rhsCst.getValue().getBitWidth()); commonWidth = std::max(commonWidth, 0); @@ -508,14 +592,14 @@ OpFoldResult LTPrimOp::fold(ArrayRef operands) { return getIntAttr(getType(), APInt(1, 0)); // lt(x, 0) -> 0 when x is unsigned - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { if (rhsCst.getValue().isZero() && lhs().getType().isa()) return getIntAttr(getType(), APInt(1, 0)); } // Comparison against constant outside type bounds. if (auto width = lhs().getType().cast().getWidth()) { - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { auto commonWidth = std::max(*width, rhsCst.getValue().getBitWidth()); commonWidth = std::max(commonWidth, 0); @@ -561,14 +645,14 @@ OpFoldResult GEQPrimOp::fold(ArrayRef operands) { return getIntAttr(getType(), APInt(1, 1)); // geq(x, 0) -> 1 when x is unsigned - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { if (rhsCst.getValue().isZero() && lhs().getType().isa()) return getIntAttr(getType(), APInt(1, 1)); } // Comparison against constant outside type bounds. if (auto width = lhs().getType().cast().getWidth()) { - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { auto commonWidth = std::max(*width, rhsCst.getValue().getBitWidth()); commonWidth = std::max(commonWidth, 0); @@ -616,7 +700,7 @@ OpFoldResult GTPrimOp::fold(ArrayRef operands) { // Comparison against constant outside type bounds. if (auto width = lhs().getType().cast().getWidth()) { - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { auto commonWidth = std::max(*width, rhsCst.getValue().getBitWidth()); commonWidth = std::max(commonWidth, 0); @@ -655,7 +739,7 @@ OpFoldResult EQPrimOp::fold(ArrayRef operands) { if (lhs() == rhs()) return getIntAttr(getType(), APInt(1, 1)); - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { /// eq(x, 1) -> x when x is 1 bit. /// TODO: Support SInt<1> on the LHS etc. if (rhsCst.getValue().isAllOnes() && lhs().getType() == getType() && @@ -669,33 +753,35 @@ OpFoldResult EQPrimOp::fold(ArrayRef operands) { } LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) { + return canonicalizePrimOp( + op, rewriter, [&](ArrayRef operands) -> OpFoldResult { + if (auto rhsCst = getConstant(operands[1])) { + auto width = + op.lhs().getType().cast().getBitWidthOrSentinel(); - if (auto rhsCst = dyn_cast_or_null(op.rhs().getDefiningOp())) { - auto width = op.lhs().getType().cast().getBitWidthOrSentinel(); + // eq(x, 0) -> not(x) when x is 1 bit. + if (rhsCst->isZero() && op.lhs().getType() == op.getType() && + op.rhs().getType() == op.getType()) { + return rewriter.create(op.getLoc(), op.lhs()) + .getResult(); + } - // eq(x, 0) -> not(x) when x is 1 bit. - if (rhsCst.value().isZero() && op.lhs().getType() == op.getType() && - op.rhs().getType() == op.getType()) { - rewriter.replaceOpWithNewOp(op, op.lhs()); - return success(); - } + // eq(x, 0) -> not(orr(x)) when x is >1 bit + if (rhsCst->isZero() && width > 1) { + auto orrOp = rewriter.create(op.getLoc(), op.lhs()); + return rewriter.create(op.getLoc(), orrOp).getResult(); + } - // eq(x, 0) -> not(orr(x)) when x is >1 bit - if (rhsCst.value().isZero() && width > 1) { - auto orrOp = rewriter.create(op.getLoc(), op.lhs()); - rewriter.replaceOpWithNewOp(op, orrOp); - return success(); - } + // eq(x, ~0) -> andr(x) when x is >1 bit + if (rhsCst->isAllOnes() && width > 1 && + op.lhs().getType() == op.rhs().getType()) { + return rewriter.create(op.getLoc(), op.lhs()) + .getResult(); + } + } - // eq(x, ~0) -> andr(x) when x is >1 bit - if (rhsCst.value().isAllOnes() && width > 1 && - op.lhs().getType() == op.rhs().getType()) { - rewriter.replaceOpWithNewOp(op, op.lhs()); - return success(); - } - } - - return failure(); + return {}; + }); } OpFoldResult NEQPrimOp::fold(ArrayRef operands) { @@ -703,7 +789,7 @@ OpFoldResult NEQPrimOp::fold(ArrayRef operands) { if (lhs() == rhs()) return getIntAttr(getType(), APInt(1, 0)); - if (auto rhsCst = operands[1].dyn_cast_or_null()) { + if (auto rhsCst = getConstant(operands[1])) { /// neq(x, 0) -> x when x is 1 bit. /// TODO: Support SInt<1> on the LHS etc. if (rhsCst.getValue().isZero() && lhs().getType() == getType() && @@ -717,31 +803,35 @@ OpFoldResult NEQPrimOp::fold(ArrayRef operands) { } LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) { - if (auto rhsCst = dyn_cast_or_null(op.rhs().getDefiningOp())) { - auto width = op.lhs().getType().cast().getBitWidthOrSentinel(); - // neq(x, 1) -> not(x) when x is 1 bit - if (rhsCst.value().isAllOnes() && op.lhs().getType() == op.getType() && - op.rhs().getType() == op.getType()) { - rewriter.replaceOpWithNewOp(op, op.lhs()); - return success(); - } + return canonicalizePrimOp( + op, rewriter, [&](ArrayRef operands) -> OpFoldResult { + if (auto rhsCst = getConstant(operands[1])) { + auto width = + op.lhs().getType().cast().getBitWidthOrSentinel(); - // neq(x, 0) -> orr(x) when x is >1 bit - if (rhsCst.value().isZero() && width > 1) { - rewriter.replaceOpWithNewOp(op, op.lhs()); - return success(); - } + // neq(x, 1) -> not(x) when x is 1 bit + if (rhsCst->isAllOnes() && op.lhs().getType() == op.getType() && + op.rhs().getType() == op.getType()) { + return rewriter.create(op.getLoc(), op.lhs()) + .getResult(); + } - // neq(x, ~0) -> not(andr(x))) when x is >1 bit - if (rhsCst.value().isAllOnes() && width > 1 && - op.lhs().getType() == op.rhs().getType()) { - auto andrOp = rewriter.create(op.getLoc(), op.lhs()); - rewriter.replaceOpWithNewOp(op, andrOp); - return success(); - } - } + // neq(x, 0) -> orr(x) when x is >1 bit + if (rhsCst->isZero() && width > 1) { + return rewriter.create(op.getLoc(), op.lhs()) + .getResult(); + } - return failure(); + // neq(x, ~0) -> not(andr(x))) when x is >1 bit + if (rhsCst->isAllOnes() && width > 1 && + op.lhs().getType() == op.rhs().getType()) { + auto andrOp = rewriter.create(op.getLoc(), op.lhs()); + return rewriter.create(op.getLoc(), andrOp).getResult(); + } + } + + return {}; + }); } //===----------------------------------------------------------------------===// @@ -753,19 +843,12 @@ OpFoldResult AsSIntPrimOp::fold(ArrayRef operands) { if (input().getType() == getType()) return input(); - if (!operands[0]) - return {}; - - // Constant clocks and resets are bool attributes. - if (auto attr = operands[0].dyn_cast()) - return getIntAttr(getType(), APInt(/*bitWidth*/ 1, attr.getValue())); - // Be careful to only fold the cast into the constant if the size is known. // Otherwise width inference may produce differently-sized constants if the // sign changes. - if (auto attr = operands[0].dyn_cast()) - if (getType().hasWidth()) - return getIntAttr(getType(), attr.getValue()); + if (getType().hasWidth()) + if (auto cst = getConstant(operands[0])) + return getIntAttr(getType(), *cst); return {}; } @@ -775,19 +858,12 @@ OpFoldResult AsUIntPrimOp::fold(ArrayRef operands) { if (input().getType() == getType()) return input(); - if (!operands[0]) - return {}; - - // Constant clocks and resets are bool attributes. - if (auto attr = operands[0].dyn_cast()) - return getIntAttr(getType(), APInt(/*bitWidth*/ 1, attr.getValue())); - // Be careful to only fold the cast into the constant if the size is known. // Otherwise width inference may produce differently-sized constants if the // sign changes. - if (auto attr = operands[0].dyn_cast()) - if (getType().hasWidth()) - return getIntAttr(getType(), attr.getValue()); + if (getType().hasWidth()) + if (auto cst = getConstant(operands[0])) + return getIntAttr(getType(), *cst); return {}; } @@ -798,8 +874,8 @@ OpFoldResult AsAsyncResetPrimOp::fold(ArrayRef operands) { return input(); // Constant fold. - if (auto attr = operands[0].dyn_cast_or_null()) - return BoolAttr::get(getContext(), attr.getValue().getBoolValue()); + if (auto cst = getConstant(operands[0])) + return BoolAttr::get(getContext(), cst->getBoolValue()); return {}; } @@ -810,8 +886,8 @@ OpFoldResult AsClockPrimOp::fold(ArrayRef operands) { return input(); // Constant fold. - if (auto attr = operands[0].dyn_cast_or_null()) - return BoolAttr::get(getContext(), attr.getValue().getBoolValue()); + if (auto cst = getConstant(operands[0])) + return BoolAttr::get(getContext(), cst->getBoolValue()); return {}; } @@ -835,9 +911,8 @@ OpFoldResult NegPrimOp::fold(ArrayRef operands) { // FIRRTL negate always adds a bit. // -x ---> 0-sext(x) or 0-zext(x) - auto cst = getExtendedConstant(getOperand(), operands[0], - getType().getWidthOrSentinel()); - if (cst.hasValue()) + if (auto cst = getExtendedConstant(getOperand(), operands[0], + getType().getWidthOrSentinel())) return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst); return {}; @@ -847,8 +922,9 @@ OpFoldResult NotPrimOp::fold(ArrayRef operands) { if (!hasKnownWidthIntTypes(*this)) return {}; - if (auto attr = operands[0].dyn_cast_or_null()) - return getIntAttr(getType(), ~attr.getValue()); + if (auto cst = getExtendedConstant(getOperand(), operands[0], + getType().getWidthOrSentinel())) + return getIntAttr(getType(), ~*cst); return {}; } @@ -858,8 +934,8 @@ OpFoldResult AndRPrimOp::fold(ArrayRef operands) { return {}; // x == -1 - if (auto attr = operands[0].dyn_cast_or_null()) - return getIntAttr(getType(), APInt(1, attr.getValue().isAllOnes())); + if (auto cst = getConstant(operands[0])) + return getIntAttr(getType(), APInt(1, cst->isAllOnes())); // one bit is identity. Only applies to UInt since we can't make a cast // here. @@ -874,8 +950,8 @@ OpFoldResult OrRPrimOp::fold(ArrayRef operands) { return {}; // x != 0 - if (auto attr = operands[0].dyn_cast_or_null()) - return getIntAttr(getType(), APInt(1, !attr.getValue().isZero())); + if (auto cst = getConstant(operands[0])) + return getIntAttr(getType(), APInt(1, !cst->isZero())); // one bit is identity. Only applies to UInt since we can't make a cast // here. @@ -890,9 +966,8 @@ OpFoldResult XorRPrimOp::fold(ArrayRef operands) { return {}; // popcount(x) & 1 - if (auto attr = operands[0].dyn_cast_or_null()) - return getIntAttr(getType(), - APInt(1, attr.getValue().countPopulation() & 1)); + if (auto cst = getConstant(operands[0])) + return getIntAttr(getType(), APInt(1, cst->countPopulation() & 1)); // one bit is identity. Only applies to UInt since we can't make a cast here. if (isUInt1(input().getType())) @@ -910,12 +985,11 @@ OpFoldResult CatPrimOp::fold(ArrayRef operands) { return {}; // Constant fold cat. - if (auto lhs = operands[0].dyn_cast_or_null()) - if (auto rhs = operands[1].dyn_cast_or_null()) { + if (auto lhs = getConstant(operands[0])) + if (auto rhs = getConstant(operands[1])) { auto destWidth = getType().getWidthOrSentinel(); - APInt tmp1 = lhs.getValue().zextOrSelf(destWidth) - << rhs.getValue().getBitWidth(); - APInt tmp2 = rhs.getValue().zextOrSelf(destWidth); + APInt tmp1 = lhs->zextOrSelf(destWidth) << rhs->getBitWidth(); + APInt tmp2 = rhs->zextOrSelf(destWidth); return getIntAttr(getType(), tmp1 | tmp2); } @@ -927,18 +1001,19 @@ LogicalResult DShlPrimOp::canonicalize(DShlPrimOp op, if (!hasKnownWidthIntTypes(op)) return failure(); - // dshl(x, cst) -> shl(x, cst). The result size is generally much wider than - // what is needed for the constant. - if (auto rhsCst = dyn_cast_or_null(op.rhs().getDefiningOp())) { - // Shift amounts are always unsigned, but shift only takes a 32-bit amount. - uint64_t shiftAmt = rhsCst.value().getLimitedValue(1ULL << 31); - auto result = - rewriter.createOrFold(op.getLoc(), op.lhs(), shiftAmt); - rewriter.replaceOpWithNewOp( - op, result, op.getType().cast().getWidthOrSentinel()); - return success(); - } - return failure(); + return canonicalizePrimOp( + op, rewriter, [&](ArrayRef operands) -> OpFoldResult { + // dshl(x, cst) -> shl(x, cst). The result size is generally much wider + // than what is needed for the constant. + if (auto rhsCst = getConstant(operands[1])) { + // Shift amounts are always unsigned, but shift only takes a 32-bit + // amount. + uint64_t shiftAmt = rhsCst->getLimitedValue(1ULL << 31); + return rewriter.createOrFold(op.getLoc(), op.lhs(), + shiftAmt); + } + return {}; + }); } LogicalResult DShrPrimOp::canonicalize(DShrPrimOp op, @@ -946,18 +1021,19 @@ LogicalResult DShrPrimOp::canonicalize(DShrPrimOp op, if (!hasKnownWidthIntTypes(op)) return failure(); - // dshr(x, cst) -> shr(x, cst). The result size is generally much wider than - // what is needed for the constant. - if (auto rhsCst = dyn_cast_or_null(op.rhs().getDefiningOp())) { - // Shift amounts are always unsigned, but shift only takes a 32-bit amount. - uint64_t shiftAmt = rhsCst.value().getLimitedValue(1ULL << 31); - auto result = - rewriter.createOrFold(op.getLoc(), op.lhs(), shiftAmt); - rewriter.replaceOpWithNewOp( - op, result, op.getType().cast().getWidthOrSentinel()); - return success(); - } - return failure(); + return canonicalizePrimOp( + op, rewriter, [&](ArrayRef operands) -> OpFoldResult { + // dshr(x, cst) -> shr(x, cst). The result size is generally much wider + // than what is needed for the constant. + if (auto rhsCst = getConstant(operands[1])) { + // Shift amounts are always unsigned, but shift only takes a 32-bit + // amount. + uint64_t shiftAmt = rhsCst->getLimitedValue(1ULL << 31); + return rewriter.createOrFold(op.getLoc(), op.lhs(), + shiftAmt); + } + return {}; + }); } LogicalResult CatPrimOp::canonicalize(CatPrimOp op, PatternRewriter &rewriter) { @@ -999,9 +1075,9 @@ OpFoldResult BitsPrimOp::fold(ArrayRef operands) { // Constant fold. if (hasKnownWidthIntTypes(*this)) - if (auto attr = operands[0].dyn_cast_or_null()) - return getIntAttr( - getType(), attr.getValue().lshr(lo()).truncOrSelf(hi() - lo() + 1)); + if (auto cst = getConstant(operands[0])) + return getIntAttr(getType(), + cst->lshr(lo()).truncOrSelf(hi() - lo() + 1)); return {}; } @@ -1061,23 +1137,23 @@ OpFoldResult MuxPrimOp::fold(ArrayRef operands) { return {}; // mux(0/1, x, y) -> x or y - if (auto cond = operands[0].dyn_cast_or_null()) { - if (cond.getValue().isZero() && low().getType() == getType()) + if (auto cond = getConstant(operands[0])) { + if (cond->isZero() && low().getType() == getType()) return low(); - if (!cond.getValue().isZero() && high().getType() == getType()) + if (!cond->isZero() && high().getType() == getType()) return high(); } // mux(cond, x, cst) - if (auto lowCst = operands[2].dyn_cast_or_null()) { + if (auto lowCst = getConstant(operands[2])) { // mux(cond, c1, c2) - if (auto highCst = operands[1].dyn_cast_or_null()) { - if (highCst.getType() == lowCst.getType() && - highCst.getValue() == lowCst.getValue()) - return highCst; + if (auto highCst = getConstant(operands[1])) { + // mux(cond, cst, cst) -> cst + if (highCst->getBitWidth() == lowCst->getBitWidth() && + *highCst == *lowCst) + return getIntAttr(getType(), *highCst); // mux(cond, 1, 0) -> cond - if (highCst.getValue().isOne() && lowCst.getValue().isZero() && - getType() == sel().getType()) + if (highCst->isOne() && lowCst->isZero() && getType() == sel().getType()) return sel(); // TODO: x ? ~0 : 0 -> sext(x) @@ -1138,24 +1214,19 @@ OpFoldResult PadPrimOp::fold(ArrayRef operands) { return {}; // Constant fold. - if (auto cst = operands[0].dyn_cast_or_null()) { + if (auto cst = getConstant(operands[0])) { auto destWidth = getType().getWidthOrSentinel(); if (destWidth == -1) return {}; - if (inputType.isSigned() && cst.getValue().getBitWidth()) - return getIntAttr(getType(), cst.getValue().sext(destWidth)); - return getIntAttr(getType(), cst.getValue().zext(destWidth)); + if (inputType.isSigned() && cst->getBitWidth()) + return getIntAttr(getType(), cst->sext(destWidth)); + return getIntAttr(getType(), cst->zext(destWidth)); } return {}; } -void PadPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - OpFoldResult ShlPrimOp::fold(ArrayRef operands) { auto input = this->input(); auto inputType = input.getType().cast(); @@ -1166,13 +1237,12 @@ OpFoldResult ShlPrimOp::fold(ArrayRef operands) { return input; // Constant fold. - if (auto cst = operands[0].dyn_cast_or_null()) { + if (auto cst = getConstant(operands[0])) { auto inputWidth = inputType.getWidthOrSentinel(); if (inputWidth != -1) { auto resultWidth = inputWidth + shiftAmount; shiftAmount = std::min(shiftAmount, resultWidth); - return getIntAttr(getType(), - cst.getValue().zext(resultWidth).shl(shiftAmount)); + return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount)); } } return {}; @@ -1199,12 +1269,12 @@ OpFoldResult ShrPrimOp::fold(ArrayRef operands) { return getIntAttr(getType(), APInt(1, 0)); // Constant fold. - if (auto cst = operands[0].dyn_cast_or_null()) { + if (auto cst = getConstant(operands[0])) { APInt value; if (inputType.isSigned()) - value = cst.getValue().ashr(std::min(shiftAmount, inputWidth - 1)); + value = cst->ashr(std::min(shiftAmount, inputWidth - 1)); else - value = cst.getValue().lshr(std::min(shiftAmount, inputWidth)); + value = cst->lshr(std::min(shiftAmount, inputWidth)); auto resultWidth = std::max(inputWidth - shiftAmount, 1); return getIntAttr(getType(), value.truncOrSelf(resultWidth)); } @@ -1249,11 +1319,11 @@ LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op, OpFoldResult HeadPrimOp::fold(ArrayRef operands) { if (hasKnownWidthIntTypes(*this)) - if (auto attr = operands[0].dyn_cast_or_null()) { + if (auto cst = getConstant(operands[0])) { int shiftAmount = input().getType().cast().getWidthOrSentinel() - amount(); - return getIntAttr( - getType(), attr.getValue().lshr(shiftAmount).truncOrSelf(amount())); + return getIntAttr(getType(), + cst->lshr(shiftAmount).truncOrSelf(amount())); } return {}; @@ -1261,9 +1331,9 @@ OpFoldResult HeadPrimOp::fold(ArrayRef operands) { OpFoldResult TailPrimOp::fold(ArrayRef operands) { if (hasKnownWidthIntTypes(*this)) - if (auto attr = operands[0].dyn_cast_or_null()) - return getIntAttr(getType(), attr.getValue().truncOrSelf( - getType().getWidthOrSentinel())); + if (auto cst = getConstant(operands[0])) + return getIntAttr(getType(), + cst->truncOrSelf(getType().getWidthOrSentinel())); return {}; } @@ -1282,18 +1352,18 @@ LogicalResult TailPrimOp::canonicalize(TailPrimOp op, LogicalResult SubaccessOp::canonicalize(SubaccessOp op, PatternRewriter &rewriter) { - if (auto index = op.index().getDefiningOp()) { - if (auto constIndex = dyn_cast(index)) { - // The SubindexOp require the index value to be unsigned 32-bits - // integer. - auto value = constIndex.value().getExtValue(); - auto valueAttr = rewriter.getI32IntegerAttr(value); - rewriter.replaceOpWithNewOp(op, op.result().getType(), - op.input(), valueAttr); - return success(); - } - } - return failure(); + return canonicalizePrimOp( + op, rewriter, [&](ArrayRef operands) -> OpFoldResult { + if (auto constIndex = getConstant(operands[1])) { + // The SubindexOp require the index value to be unsigned 32-bits + // integer. + auto value = constIndex->getExtValue(); + auto valueAttr = rewriter.getI32IntegerAttr(value); + return rewriter.createOrFold( + op.getLoc(), op.result().getType(), op.input(), valueAttr); + } + return {}; + }); } //===----------------------------------------------------------------------===// diff --git a/test/Dialect/FIRRTL/SFCTests/constantProp.mlir b/test/Dialect/FIRRTL/SFCTests/constantProp.mlir index 345baf4d44..94503c5313 100644 --- a/test/Dialect/FIRRTL/SFCTests/constantProp.mlir +++ b/test/Dialect/FIRRTL/SFCTests/constantProp.mlir @@ -181,10 +181,8 @@ firrtl.circuit "padZeroReg" { %n = firrtl.node %c171_ui8 : !firrtl.uint<8> %1 = firrtl.cat %n, %r : (!firrtl.uint<8>, !firrtl.uint<8>) -> !firrtl.uint<16> firrtl.connect %z, %1 : !firrtl.uint<16>, !firrtl.uint<16> - // CHECK: %[[C7:.+]] = firrtl.constant 171 : !firrtl.uint<8> - // CHECK: %[[invalid:.+]] = firrtl.invalidvalue : !firrtl.uint<8> - // CHECK: %[[C7_0:.+]] = firrtl.cat %c171_ui8, %invalid_ui8 - // CHECK-NEXT: firrtl.connect %z, %[[C7_0]] : !firrtl.uint<16>, !firrtl.uint<16> + // CHECK: %[[TMP:.+]] = firrtl.constant 43776 : !firrtl.uint<16> + // CHECK-NEXT: firrtl.connect %z, %[[TMP]] : !firrtl.uint<16>, !firrtl.uint<16> } } diff --git a/test/Dialect/FIRRTL/SFCTests/invalid-reg-fail.fir b/test/Dialect/FIRRTL/SFCTests/invalid-reg-fail.fir deleted file mode 100644 index bfb7603e99..0000000000 --- a/test/Dialect/FIRRTL/SFCTests/invalid-reg-fail.fir +++ /dev/null @@ -1,844 +0,0 @@ -; RUN: firtool -split-input-file -verilog %s | FileCheck %s -; XFAIL: * - -; This test checks register removal behavior for situations where the register -; is invalidated _through a primitive operation_. This is intended to tease out -; gnarly bugs where, due to a combination of canonicalization, folding, and -; constant propagation, CIRCT does not remove registers which the Scala FIRRTL -; Compiler (SFC) does. The CHECK/CHECK-NOT statements in this test indicate the -; SFC behavior. -; -; This test contains FAILING cases which should be fixed. For passing cases, -; see invalid-reg-pass.fir. -; -; The FIRRTL circuits in this file were generated using: -; https://github.com/seldridge/firrtl-torture/blob/main/Invalid.scala - -circuit add : - module add : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<5> - output out_1 : UInt<5> - output out_2 : UInt<5> - output out_3 : UInt<5> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = add(in_1, in_0) - node _T_1 = tail(_T, 1) - r_0 <= _T_1 - out_0 <= r_0 - node _T_2 = add(in_1, invalid) - node _T_3 = tail(_T_2, 1) - r_1 <= _T_3 - out_1 <= r_1 - node _T_4 = add(invalid, in_0) - node _T_5 = tail(_T_4, 1) - r_2 <= _T_5 - out_2 <= r_2 - node _T_6 = add(invalid, invalid) - node _T_7 = tail(_T_6, 1) - r_3 <= _T_7 - out_3 <= r_3 - - ; CHECK-LABEL: module add - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit andr : - module andr : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<1> - output out_1 : UInt<1> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = andr(in) - r_0 <= _T - out_0 <= r_0 - node _T_1 = andr(invalid) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module andr - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit asSInt : - module asSInt : - input clock : Clock - input reset : UInt<1> - input in : UInt<2> - output out_0 : SInt<2> - output out_1 : SInt<2> - - wire invalid : UInt<2> - invalid is invalid - reg r_0 : SInt<2>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : SInt<2>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = asSInt(in) - r_0 <= _T - out_0 <= r_0 - node _T_1 = asSInt(invalid) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module asSInt - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit asUInt : - module asUInt : - input clock : Clock - input reset : UInt<1> - input in : SInt<2> - output out_0 : UInt<2> - output out_1 : UInt<2> - - wire invalid : SInt<2> - invalid is invalid - reg r_0 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = asUInt(in) - r_0 <= _T - out_0 <= r_0 - node _T_1 = asUInt(invalid) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module asUInt - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit bits : - module bits : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<2> - output out_1 : UInt<2> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = bits(in, 3, 2) - r_0 <= _T - out_0 <= r_0 - node _T_1 = bits(invalid, 3, 2) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module bits - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit cat : - module cat : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<2> - input in_1 : UInt<2> - output out_0 : UInt<4> - output out_1 : UInt<4> - output out_2 : UInt<4> - output out_3 : UInt<4> - - wire invalid : UInt<2> - invalid is invalid - reg r_0 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = cat(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = cat(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = cat(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = cat(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module cat - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit div : - module div : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<4> - output out_1 : UInt<4> - output out_2 : UInt<4> - output out_3 : UInt<4> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = div(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = div(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = div(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = div(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module div - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit dshl : - module dshl : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<2> - input in_1 : UInt<2> - output out_0 : UInt<5> - output out_1 : UInt<5> - output out_2 : UInt<5> - output out_3 : UInt<5> - - wire invalid : UInt<2> - invalid is invalid - reg r_0 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = dshl(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = dshl(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = dshl(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = dshl(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module dshl - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit dshr : - module dshr : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<2> - input in_1 : UInt<2> - output out_0 : UInt<2> - output out_1 : UInt<2> - output out_2 : UInt<2> - output out_3 : UInt<2> - - wire invalid : UInt<2> - invalid is invalid - reg r_0 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = gt(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = gt(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = gt(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = gt(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module dshr - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit geq : - module geq : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<1> - output out_1 : UInt<1> - output out_2 : UInt<1> - output out_3 : UInt<1> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = geq(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = geq(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = geq(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = geq(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module geq - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit gt : - module gt : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<1> - output out_1 : UInt<1> - output out_2 : UInt<1> - output out_3 : UInt<1> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = gt(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = gt(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = gt(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = gt(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module gt - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit head : - module head : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<2> - output out_1 : UInt<2> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = head(in, 2) - r_0 <= _T - out_0 <= r_0 - node _T_1 = head(invalid, 2) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module head - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit leq : - module leq : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<1> - output out_1 : UInt<1> - output out_2 : UInt<1> - output out_3 : UInt<1> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = leq(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = leq(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = leq(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = leq(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module leq - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit lt : - module lt : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<1> - output out_1 : UInt<1> - output out_2 : UInt<1> - output out_3 : UInt<1> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = lt(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = lt(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = lt(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = lt(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module lt - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit mul : - module mul : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<8> - output out_1 : UInt<8> - output out_2 : UInt<8> - output out_3 : UInt<8> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<8>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<8>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<8>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<8>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = mul(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = mul(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = mul(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = mul(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module mul - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK: r_2 - ; CHECK: r_3 - - ; // ----- - -circuit not : - module not : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<4> - output out_1 : UInt<4> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = not(in) - r_0 <= _T - out_0 <= r_0 - node _T_1 = not(invalid) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module not - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit orr : - module orr : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<1> - output out_1 : UInt<1> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = orr(in) - r_0 <= _T - out_0 <= r_0 - node _T_1 = orr(invalid) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module orr - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit rem : - module rem : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<4> - output out_1 : UInt<4> - output out_2 : UInt<4> - output out_3 : UInt<4> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = rem(in_1, in_0) - r_0 <= _T - out_0 <= r_0 - node _T_1 = rem(in_1, invalid) - r_1 <= _T_1 - out_1 <= r_1 - node _T_2 = rem(invalid, in_0) - r_2 <= _T_2 - out_2 <= r_2 - node _T_3 = rem(invalid, invalid) - r_3 <= _T_3 - out_3 <= r_3 - - ; CHECK-LABEL: module rem - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit shl : - module shl : - input clock : Clock - input reset : UInt<1> - input in : UInt<2> - output out_0 : UInt<4> - output out_1 : UInt<4> - - wire invalid : UInt<2> - invalid is invalid - reg r_0 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<4>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = shl(in, 2) - r_0 <= _T - out_0 <= r_0 - node _T_1 = shl(invalid, 2) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module shl - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit shr : - module shr : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<2> - output out_1 : UInt<2> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = shr(in, 2) - r_0 <= _T - out_0 <= r_0 - node _T_1 = shr(invalid, 2) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module shr - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit sub : - module sub : - input clock : Clock - input reset : UInt<1> - input in_0 : UInt<4> - input in_1 : UInt<4> - output out_0 : UInt<5> - output out_1 : UInt<5> - output out_2 : UInt<5> - output out_3 : UInt<5> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_1) - reg r_2 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_2) - reg r_3 : UInt<5>, clock with : - reset => (UInt<1>("h0"), r_3) - node _T = sub(in_1, in_0) - node _T_1 = tail(_T, 1) - r_0 <= _T_1 - out_0 <= r_0 - node _T_2 = sub(in_1, invalid) - node _T_3 = tail(_T_2, 1) - r_1 <= _T_3 - out_1 <= r_1 - node _T_4 = sub(invalid, in_0) - node _T_5 = tail(_T_4, 1) - r_2 <= _T_5 - out_2 <= r_2 - node _T_6 = sub(invalid, invalid) - node _T_7 = tail(_T_6, 1) - r_3 <= _T_7 - out_3 <= r_3 - - ; CHECK-LABEL: module sub - ; CHECK: r_0 - ; CHECK: r_1 - ; CHECK: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit tail : - module tail : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<2> - output out_1 : UInt<2> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<2>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = tail(in, 2) - r_0 <= _T - out_0 <= r_0 - node _T_1 = tail(invalid, 2) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module tail - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 - - ; // ----- - -circuit xorr : - module xorr : - input clock : Clock - input reset : UInt<1> - input in : UInt<4> - output out_0 : UInt<1> - output out_1 : UInt<1> - - wire invalid : UInt<4> - invalid is invalid - reg r_0 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_0) - reg r_1 : UInt<1>, clock with : - reset => (UInt<1>("h0"), r_1) - node _T = xorr(in) - r_0 <= _T - out_0 <= r_0 - node _T_1 = xorr(invalid) - r_1 <= _T_1 - out_1 <= r_1 - - ; CHECK-LABEL: module xorr - ; CHECK: r_0 - ; CHECK-NOT: r_1 - ; CHECK-NOT: r_2 - ; CHECK-NOT: r_3 diff --git a/test/Dialect/FIRRTL/SFCTests/invalid-reg-pass.fir b/test/Dialect/FIRRTL/SFCTests/invalid-reg-pass.fir index ab699f2935..f7f33d4ebc 100644 --- a/test/Dialect/FIRRTL/SFCTests/invalid-reg-pass.fir +++ b/test/Dialect/FIRRTL/SFCTests/invalid-reg-pass.fir @@ -13,6 +13,52 @@ ; The FIRRTL circuits in this file were generated using: ; https://github.com/seldridge/firrtl-torture/blob/main/Invalid.scala +circuit add : + module add : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<5> + output out_1 : UInt<5> + output out_2 : UInt<5> + output out_3 : UInt<5> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = add(in_1, in_0) + node _T_1 = tail(_T, 1) + r_0 <= _T_1 + out_0 <= r_0 + node _T_2 = add(in_1, invalid) + node _T_3 = tail(_T_2, 1) + r_1 <= _T_3 + out_1 <= r_1 + node _T_4 = add(invalid, in_0) + node _T_5 = tail(_T_4, 1) + r_2 <= _T_5 + out_2 <= r_2 + node _T_6 = add(invalid, invalid) + node _T_7 = tail(_T_6, 1) + r_3 <= _T_7 + out_3 <= r_3 + + ; CHECK-LABEL: module add + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + circuit and : module and : input clock : Clock @@ -78,7 +124,7 @@ circuit asAsyncReset : ; CHECK-LABEL: module asAsyncReset ; CHECK: r_0 - ; CHECK: r_1 + ; CHECK-NOT: r_1 <-- fixed; upstream to Scala FIRRTL impl? ; // ----- @@ -105,7 +151,7 @@ circuit asClock : ; CHECK-LABEL: module asClock ; CHECK: r_0 - ; CHECK: r_1 + ; CHECK-NOT: r_1 <-- fixed; upstream to Scala FIRRTL impl? ; CHECK-NOT: r_2 ; CHECK-NOT: r_3 @@ -134,7 +180,7 @@ circuit cvt : ; CHECK-LABEL: module cvt ; CHECK: r_0 - ; CHECK: r_1 + ; CHECK-NOT: r_1 <-- fixed; upstream to Scala FIRRTL impl? ; CHECK-NOT: r_2 ; CHECK-NOT: r_3 @@ -366,3 +412,788 @@ circuit xor : ; CHECK: r_1 ; CHECK: r_2 ; CHECK-NOT: r_3 + + ; // ----- + +circuit andr : + module andr : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<1> + output out_1 : UInt<1> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = andr(in) + r_0 <= _T + out_0 <= r_0 + node _T_1 = andr(invalid) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module andr + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit asSInt : + module asSInt : + input clock : Clock + input reset : UInt<1> + input in : UInt<2> + output out_0 : SInt<2> + output out_1 : SInt<2> + + wire invalid : UInt<2> + invalid is invalid + reg r_0 : SInt<2>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : SInt<2>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = asSInt(in) + r_0 <= _T + out_0 <= r_0 + node _T_1 = asSInt(invalid) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module asSInt + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit asUInt : + module asUInt : + input clock : Clock + input reset : UInt<1> + input in : SInt<2> + output out_0 : UInt<2> + output out_1 : UInt<2> + + wire invalid : SInt<2> + invalid is invalid + reg r_0 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = asUInt(in) + r_0 <= _T + out_0 <= r_0 + node _T_1 = asUInt(invalid) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module asUInt + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit bits : + module bits : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<2> + output out_1 : UInt<2> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = bits(in, 3, 2) + r_0 <= _T + out_0 <= r_0 + node _T_1 = bits(invalid, 3, 2) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module bits + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit cat : + module cat : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<2> + input in_1 : UInt<2> + output out_0 : UInt<4> + output out_1 : UInt<4> + output out_2 : UInt<4> + output out_3 : UInt<4> + + wire invalid : UInt<2> + invalid is invalid + reg r_0 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = cat(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = cat(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = cat(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = cat(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module cat + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit dshl : + module dshl : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<2> + input in_1 : UInt<2> + output out_0 : UInt<5> + output out_1 : UInt<5> + output out_2 : UInt<5> + output out_3 : UInt<5> + + wire invalid : UInt<2> + invalid is invalid + reg r_0 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = dshl(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = dshl(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = dshl(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = dshl(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module dshl + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit dshr : + module dshr : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<2> + input in_1 : UInt<2> + output out_0 : UInt<2> + output out_1 : UInt<2> + output out_2 : UInt<2> + output out_3 : UInt<2> + + wire invalid : UInt<2> + invalid is invalid + reg r_0 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = dshr(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = dshr(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = dshr(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = dshr(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module dshr + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit head : + module head : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<2> + output out_1 : UInt<2> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = head(in, 2) + r_0 <= _T + out_0 <= r_0 + node _T_1 = head(invalid, 2) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module head + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit lt : + module lt : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<1> + output out_1 : UInt<1> + output out_2 : UInt<1> + output out_3 : UInt<1> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = lt(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = lt(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = lt(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = lt(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module lt + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit gt : + module gt : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<1> + output out_1 : UInt<1> + output out_2 : UInt<1> + output out_3 : UInt<1> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = gt(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = gt(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = gt(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = gt(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module gt + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit leq : + module leq : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<1> + output out_1 : UInt<1> + output out_2 : UInt<1> + output out_3 : UInt<1> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = leq(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = leq(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = leq(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = leq(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module leq + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit geq : + module geq : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<1> + output out_1 : UInt<1> + output out_2 : UInt<1> + output out_3 : UInt<1> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = geq(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = geq(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = geq(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = geq(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module geq + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit mul : + module mul : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<8> + output out_1 : UInt<8> + output out_2 : UInt<8> + output out_3 : UInt<8> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<8>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<8>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<8>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<8>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = mul(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = mul(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = mul(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = mul(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module mul + ; CHECK: r_0 + ; CHECK-NOT: r_1 <-- fixed; upstream to Scala FIRRTL impl? + ; CHECK-NOT: r_2 <-- fixed; upstream to Scala FIRRTL impl? + ; CHECK-NOT: r_3 <-- fixed; upstream to Scala FIRRTL impl? + + ; // ----- + +circuit not : + module not : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<4> + output out_1 : UInt<4> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = not(in) + r_0 <= _T + out_0 <= r_0 + node _T_1 = not(invalid) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module not + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit orr : + module orr : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<1> + output out_1 : UInt<1> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = orr(in) + r_0 <= _T + out_0 <= r_0 + node _T_1 = orr(invalid) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module orr + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit shl : + module shl : + input clock : Clock + input reset : UInt<1> + input in : UInt<2> + output out_0 : UInt<4> + output out_1 : UInt<4> + + wire invalid : UInt<2> + invalid is invalid + reg r_0 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = shl(in, 2) + r_0 <= _T + out_0 <= r_0 + node _T_1 = shl(invalid, 2) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module shl + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit shr : + module shr : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<2> + output out_1 : UInt<2> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = shr(in, 2) + r_0 <= _T + out_0 <= r_0 + node _T_1 = shr(invalid, 2) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module shr + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit sub : + module sub : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<5> + output out_1 : UInt<5> + output out_2 : UInt<5> + output out_3 : UInt<5> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<5>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = sub(in_1, in_0) + node _T_1 = tail(_T, 1) + r_0 <= _T_1 + out_0 <= r_0 + node _T_2 = sub(in_1, invalid) + node _T_3 = tail(_T_2, 1) + r_1 <= _T_3 + out_1 <= r_1 + node _T_4 = sub(invalid, in_0) + node _T_5 = tail(_T_4, 1) + r_2 <= _T_5 + out_2 <= r_2 + node _T_6 = sub(invalid, invalid) + node _T_7 = tail(_T_6, 1) + r_3 <= _T_7 + out_3 <= r_3 + + ; CHECK-LABEL: module sub + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit tail : + module tail : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<2> + output out_1 : UInt<2> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<2>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = tail(in, 2) + r_0 <= _T + out_0 <= r_0 + node _T_1 = tail(invalid, 2) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module tail + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 + + ; // ----- + +circuit div : + module div : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<4> + output out_1 : UInt<4> + output out_2 : UInt<4> + output out_3 : UInt<4> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = div(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = div(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = div(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = div(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module div + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK-NOT: r_2 <-- fixed; upstream to Scala FIRRTL impl? + ; CHECK-NOT: r_3 + + ; // ----- + +circuit rem : + module rem : + input clock : Clock + input reset : UInt<1> + input in_0 : UInt<4> + input in_1 : UInt<4> + output out_0 : UInt<4> + output out_1 : UInt<4> + output out_2 : UInt<4> + output out_3 : UInt<4> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_1) + reg r_2 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_2) + reg r_3 : UInt<4>, clock with : + reset => (UInt<1>("h0"), r_3) + node _T = rem(in_1, in_0) + r_0 <= _T + out_0 <= r_0 + node _T_1 = rem(in_1, invalid) + r_1 <= _T_1 + out_1 <= r_1 + node _T_2 = rem(invalid, in_0) + r_2 <= _T_2 + out_2 <= r_2 + node _T_3 = rem(invalid, invalid) + r_3 <= _T_3 + out_3 <= r_3 + + ; CHECK-LABEL: module rem + ; CHECK: r_0 + ; CHECK: r_1 + ; CHECK-NOT: r_2 <-- fixed; upstream to Scala FIRRTL impl? + ; CHECK-NOT: r_3 + + ; // ----- + +circuit xorr : + module xorr : + input clock : Clock + input reset : UInt<1> + input in : UInt<4> + output out_0 : UInt<1> + output out_1 : UInt<1> + + wire invalid : UInt<4> + invalid is invalid + reg r_0 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_0) + reg r_1 : UInt<1>, clock with : + reset => (UInt<1>("h0"), r_1) + node _T = xorr(in) + r_0 <= _T + out_0 <= r_0 + node _T_1 = xorr(invalid) + r_1 <= _T_1 + out_1 <= r_1 + + ; CHECK-LABEL: module xorr + ; CHECK: r_0 + ; CHECK-NOT: r_1 + ; CHECK-NOT: r_2 + ; CHECK-NOT: r_3 diff --git a/test/Dialect/FIRRTL/canonicalization.mlir b/test/Dialect/FIRRTL/canonicalization.mlir index 285d8f9b3a..b2b5426c8b 100644 --- a/test/Dialect/FIRRTL/canonicalization.mlir +++ b/test/Dialect/FIRRTL/canonicalization.mlir @@ -10,6 +10,10 @@ firrtl.module @Casts(in %ui1 : !firrtl.uint<1>, in %si1 : !firrtl.sint<1>, %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> %c1_si1 = firrtl.constant 1 : !firrtl.sint<1> + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + %invalid_si1 = firrtl.invalidvalue : !firrtl.sint<1> + %invalid_clock = firrtl.invalidvalue : !firrtl.clock + %invalid_asyncreset = firrtl.invalidvalue : !firrtl.asyncreset // No effect // CHECK: firrtl.connect %out_ui1, %ui1 : !firrtl.uint<1>, !firrtl.uint<1> @@ -38,6 +42,20 @@ firrtl.module @Casts(in %ui1 : !firrtl.uint<1>, in %si1 : !firrtl.sint<1>, // CHECK: firrtl.connect %out_asyncreset, %c1_asyncreset : !firrtl.asyncreset, !firrtl.asyncreset %7 = firrtl.asAsyncReset %c1_ui1 : (!firrtl.uint<1>) -> !firrtl.asyncreset firrtl.connect %out_asyncreset, %7 : !firrtl.asyncreset, !firrtl.asyncreset + + // Invalid values + // CHECK: firrtl.connect %out_ui1, %c0_ui1 : !firrtl.uint<1>, !firrtl.uint<1> + %8 = firrtl.asUInt %invalid_si1 : (!firrtl.sint<1>) -> !firrtl.uint<1> + firrtl.connect %out_ui1, %8 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %out_si1, %c0_si1 : !firrtl.sint<1>, !firrtl.sint<1> + %9 = firrtl.asSInt %invalid_ui1 : (!firrtl.uint<1>) -> !firrtl.sint<1> + firrtl.connect %out_si1, %9 : !firrtl.sint<1>, !firrtl.sint<1> + // CHECK: firrtl.connect %out_clock, %c0_clock : !firrtl.clock, !firrtl.clock + %10 = firrtl.asClock %invalid_ui1 : (!firrtl.uint<1>) -> !firrtl.clock + firrtl.connect %out_clock, %10 : !firrtl.clock, !firrtl.clock + // CHECK: firrtl.connect %out_asyncreset, %c0_asyncreset : !firrtl.asyncreset, !firrtl.asyncreset + %11 = firrtl.asAsyncReset %invalid_ui1 : (!firrtl.uint<1>) -> !firrtl.asyncreset + firrtl.connect %out_asyncreset, %11 : !firrtl.asyncreset, !firrtl.asyncreset } // CHECK-LABEL: firrtl.module @Div @@ -282,7 +300,19 @@ firrtl.module @EQ(in %in1: !firrtl.uint<1>, %4 = firrtl.eq %in4, %c0_ui1 : (!firrtl.uint<4>, !firrtl.uint<1>) -> !firrtl.uint<1> firrtl.connect %out, %4 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: [[ORR:%.+]] = firrtl.orr %in4 + // CHECK-NEXT: firrtl.not [[ORR]] + // CHECK-NEXT: firrtl.connect + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + %5 = firrtl.eq %in1, %invalid_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.connect %out, %5 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK-NEXT: firrtl.not %in1 + // CHECK-NEXT: firrtl.connect + + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %6 = firrtl.eq %in4, %invalid_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1> + firrtl.connect %out, %6 : !firrtl.uint<1>, !firrtl.uint<1> // CHECK: [[ORR:%.+]] = firrtl.orr %in4 // CHECK-NEXT: firrtl.not [[ORR]] // CHECK-NEXT: firrtl.connect @@ -308,9 +338,15 @@ firrtl.module @NEQ(in %in1: !firrtl.uint<1>, // CHECK: firrtl.orr %in4 // CHECK-NEXT: firrtl.connect - %c15_ui4 = firrtl.constant 15 : !firrtl.uint<4> - %3 = firrtl.neq %in4, %c15_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1> + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %3 = firrtl.neq %in4, %invalid_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1> firrtl.connect %out, %3 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.orr %in4 + // CHECK-NEXT: firrtl.connect + + %c15_ui4 = firrtl.constant 15 : !firrtl.uint<4> + %4 = firrtl.neq %in4, %c15_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1> + firrtl.connect %out, %4 : !firrtl.uint<1>, !firrtl.uint<1> // CHECK: [[ANDR:%.+]] = firrtl.andr %in4 // CHECK-NEXT: firrtl.not [[ANDR]] // CHECK-NEXT: firrtl.connect @@ -319,7 +355,8 @@ firrtl.module @NEQ(in %in1: !firrtl.uint<1>, // CHECK-LABEL: firrtl.module @Cat firrtl.module @Cat(in %in4: !firrtl.uint<4>, out %out4: !firrtl.uint<4>, - out %outcst: !firrtl.uint<8>) { + out %outcst: !firrtl.uint<8>, + out %outcst2: !firrtl.uint<8>) { // CHECK: firrtl.connect %out4, %in4 %0 = firrtl.bits %in4 3 to 2 : (!firrtl.uint<4>) -> !firrtl.uint<2> @@ -332,7 +369,12 @@ firrtl.module @Cat(in %in4: !firrtl.uint<4>, %c3_ui4 = firrtl.constant 3 : !firrtl.uint<4> %3 = firrtl.cat %c15_ui4, %c3_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<8> firrtl.connect %outcst, %3 : !firrtl.uint<8>, !firrtl.uint<8> - } + + // CHECK: firrtl.connect %outcst2, %c0_ui8 + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %4 = firrtl.cat %invalid_ui4, %invalid_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<8> + firrtl.connect %outcst2, %4 : !firrtl.uint<8>, !firrtl.uint<8> +} // CHECK-LABEL: firrtl.module @Bits firrtl.module @Bits(in %in1: !firrtl.uint<1>, @@ -363,6 +405,11 @@ firrtl.module @Bits(in %in1: !firrtl.uint<1>, // CHECK: firrtl.connect %out1, %in1 %5 = firrtl.bits %in1 0 to 0 : (!firrtl.uint<1>) -> !firrtl.uint<1> firrtl.connect %out1, %5 : !firrtl.uint<1>, !firrtl.uint<1> + + // CHECK: firrtl.connect %out2, %c0_ui2 + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %6 = firrtl.bits %invalid_ui4 2 to 1 : (!firrtl.uint<4>) -> !firrtl.uint<2> + firrtl.connect %out2, %6 : !firrtl.uint<2>, !firrtl.uint<2> } // CHECK-LABEL: firrtl.module @Head @@ -383,6 +430,11 @@ firrtl.module @Head(in %in4u: !firrtl.uint<4>, %c10_ui4 = firrtl.constant 10 : !firrtl.uint<4> %2 = firrtl.head %c10_ui4, 3 : (!firrtl.uint<4>) -> !firrtl.uint<3> firrtl.connect %out3u, %2 : !firrtl.uint<3>, !firrtl.uint<3> + + // CHECK: firrtl.connect %out3u, %c0_ui3 + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %3 = firrtl.head %invalid_ui4, 3 : (!firrtl.uint<4>) -> !firrtl.uint<3> + firrtl.connect %out3u, %3 : !firrtl.uint<3>, !firrtl.uint<3> } // CHECK-LABEL: firrtl.module @Mux @@ -417,6 +469,11 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>, // CHECK: firrtl.connect %out, %invalid_ui4 %7 = firrtl.mux (%cond, %invalid_ui4, %invalid_ui4) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> firrtl.connect %out, %7 : !firrtl.uint<4>, !firrtl.uint<4> + + // CHECK: firrtl.connect %out, %c7_ui4 + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + %8 = firrtl.mux (%invalid_ui1, %in, %c7_ui4) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.connect %out, %8 : !firrtl.uint<4>, !firrtl.uint<4> } // CHECK-LABEL: firrtl.module @Pad @@ -451,6 +508,11 @@ firrtl.module @Shl(in %in1u: !firrtl.uint<1>, %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> %1 = firrtl.shl %c1_ui1, 3 : (!firrtl.uint<1>) -> !firrtl.uint<4> firrtl.connect %outu, %1 : !firrtl.uint<4>, !firrtl.uint<4> + + // CHECK: firrtl.connect %outu, %c0_ui4 + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + %2 = firrtl.shl %invalid_ui1, 3 : (!firrtl.uint<1>) -> !firrtl.uint<4> + firrtl.connect %outu, %2 : !firrtl.uint<4>, !firrtl.uint<4> } // CHECK-LABEL: firrtl.module @Shr @@ -512,6 +574,11 @@ firrtl.module @Shr(in %in1u: !firrtl.uint<1>, %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> %9 = firrtl.dshr %in0u, %c1_ui1 : (!firrtl.uint<0>, !firrtl.uint<1>) -> !firrtl.uint<0> firrtl.connect %out1u, %9 : !firrtl.uint<1>, !firrtl.uint<0> + + // CHECK: firrtl.connect %out1u, %c0_ui1 + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %10 = firrtl.shr %invalid_ui4, 3 : (!firrtl.uint<4>) -> !firrtl.uint<1> + firrtl.connect %out1u, %10 : !firrtl.uint<1>, !firrtl.uint<1> } // CHECK-LABEL: firrtl.module @Tail @@ -532,31 +599,95 @@ firrtl.module @Tail(in %in4u: !firrtl.uint<4>, %c10_ui4 = firrtl.constant 10 : !firrtl.uint<4> %2 = firrtl.tail %c10_ui4, 1 : (!firrtl.uint<4>) -> !firrtl.uint<3> firrtl.connect %out3u, %2 : !firrtl.uint<3>, !firrtl.uint<3> + + // CHECK: firrtl.connect %out3u, %c0_ui3 + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %3 = firrtl.tail %invalid_ui4, 1 : (!firrtl.uint<4>) -> !firrtl.uint<3> + firrtl.connect %out3u, %3 : !firrtl.uint<3>, !firrtl.uint<3> } // CHECK-LABEL: firrtl.module @Andr -firrtl.circuit "Andr" { - firrtl.module @Andr(out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>, - out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>) { - %c2_ui2 = firrtl.constant 2 : !firrtl.uint<2> - %c3_ui2 = firrtl.constant 3 : !firrtl.uint<2> - %cn2_si2 = firrtl.constant -2 : !firrtl.sint<2> - %cn1_si2 = firrtl.constant -1 : !firrtl.sint<2> - %0 = firrtl.andr %c2_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> - %1 = firrtl.andr %c3_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> - %2 = firrtl.andr %cn2_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> - %3 = firrtl.andr %cn1_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> - // CHECK: %[[ONE:.+]] = firrtl.constant 1 : !firrtl.uint<1> - // CHECK: %[[ZERO:.+]] = firrtl.constant 0 : !firrtl.uint<1> - // CHECK: firrtl.connect %a, %[[ZERO]] - firrtl.connect %a, %0 : !firrtl.uint<1>, !firrtl.uint<1> - // CHECK: firrtl.connect %b, %[[ONE]] - firrtl.connect %b, %1 : !firrtl.uint<1>, !firrtl.uint<1> - // CHECK: firrtl.connect %c, %[[ZERO]] - firrtl.connect %c, %2 : !firrtl.uint<1>, !firrtl.uint<1> - // CHECK: firrtl.connect %d, %[[ONE]] - firrtl.connect %d, %3 : !firrtl.uint<1>, !firrtl.uint<1> - } +firrtl.module @Andr(out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>, + out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>, + out %e: !firrtl.uint<1>) { + %invalid_ui2 = firrtl.invalidvalue : !firrtl.uint<2> + %c2_ui2 = firrtl.constant 2 : !firrtl.uint<2> + %c3_ui2 = firrtl.constant 3 : !firrtl.uint<2> + %cn2_si2 = firrtl.constant -2 : !firrtl.sint<2> + %cn1_si2 = firrtl.constant -1 : !firrtl.sint<2> + %0 = firrtl.andr %c2_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %1 = firrtl.andr %c3_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %2 = firrtl.andr %cn2_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> + %3 = firrtl.andr %cn1_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> + %4 = firrtl.andr %invalid_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + // CHECK: %[[ZERO:.+]] = firrtl.constant 0 : !firrtl.uint<1> + // CHECK: %[[ONE:.+]] = firrtl.constant 1 : !firrtl.uint<1> + // CHECK: firrtl.connect %a, %[[ZERO]] + firrtl.connect %a, %0 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %b, %[[ONE]] + firrtl.connect %b, %1 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %c, %[[ZERO]] + firrtl.connect %c, %2 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %d, %[[ONE]] + firrtl.connect %d, %3 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %e, %[[ZERO]] + firrtl.connect %e, %4 : !firrtl.uint<1>, !firrtl.uint<1> +} + +// CHECK-LABEL: firrtl.module @Orr +firrtl.module @Orr(out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>, + out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>, + out %e: !firrtl.uint<1>) { + %invalid_ui2 = firrtl.invalidvalue : !firrtl.uint<2> + %c0_ui2 = firrtl.constant 0 : !firrtl.uint<2> + %c2_ui2 = firrtl.constant 2 : !firrtl.uint<2> + %cn0_si2 = firrtl.constant 0 : !firrtl.sint<2> + %cn2_si2 = firrtl.constant -2 : !firrtl.sint<2> + %0 = firrtl.orr %c0_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %1 = firrtl.orr %c2_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %2 = firrtl.orr %cn0_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> + %3 = firrtl.orr %cn2_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> + %4 = firrtl.orr %invalid_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + // CHECK: %[[ZERO:.+]] = firrtl.constant 0 : !firrtl.uint<1> + // CHECK: %[[ONE:.+]] = firrtl.constant 1 : !firrtl.uint<1> + // CHECK: firrtl.connect %a, %[[ZERO]] + firrtl.connect %a, %0 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %b, %[[ONE]] + firrtl.connect %b, %1 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %c, %[[ZERO]] + firrtl.connect %c, %2 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %d, %[[ONE]] + firrtl.connect %d, %3 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %e, %[[ZERO]] + firrtl.connect %e, %4 : !firrtl.uint<1>, !firrtl.uint<1> +} + +// CHECK-LABEL: firrtl.module @Xorr +firrtl.module @Xorr(out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>, + out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>, + out %e: !firrtl.uint<1>) { + %invalid_ui2 = firrtl.invalidvalue : !firrtl.uint<2> + %c3_ui2 = firrtl.constant 3 : !firrtl.uint<2> + %c2_ui2 = firrtl.constant 2 : !firrtl.uint<2> + %cn1_si2 = firrtl.constant -1 : !firrtl.sint<2> + %cn2_si2 = firrtl.constant -2 : !firrtl.sint<2> + %0 = firrtl.xorr %c3_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %1 = firrtl.xorr %c2_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %2 = firrtl.xorr %cn1_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> + %3 = firrtl.xorr %cn2_si2 : (!firrtl.sint<2>) -> !firrtl.uint<1> + %4 = firrtl.xorr %invalid_ui2 : (!firrtl.uint<2>) -> !firrtl.uint<1> + // CHECK: %[[ZERO:.+]] = firrtl.constant 0 : !firrtl.uint<1> + // CHECK: %[[ONE:.+]] = firrtl.constant 1 : !firrtl.uint<1> + // CHECK: firrtl.connect %a, %[[ZERO]] + firrtl.connect %a, %0 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %b, %[[ONE]] + firrtl.connect %b, %1 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %c, %[[ZERO]] + firrtl.connect %c, %2 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %d, %[[ONE]] + firrtl.connect %d, %3 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.connect %e, %[[ZERO]] + firrtl.connect %e, %4 : !firrtl.uint<1>, !firrtl.uint<1> } // CHECK-LABEL: firrtl.module @Reduce @@ -581,6 +712,12 @@ firrtl.module @subaccess(out %result: !firrtl.uint<8>, in %vec0: !firrtl.vector< %c11_ui8 = firrtl.constant 11 : !firrtl.uint<8> %0 = firrtl.subaccess %vec0[%c11_ui8] : !firrtl.vector, 16>, !firrtl.uint<8> firrtl.connect %result, %0 :!firrtl.uint<8>, !firrtl.uint<8> + + // CHECK: [[TMP:%.+]] = firrtl.subindex %vec0[0] + // CHECK-NEXT: firrtl.connect %result, [[TMP]] + %invalid_ui8 = firrtl.invalidvalue : !firrtl.uint<8> + %1 = firrtl.subaccess %vec0[%invalid_ui8] : !firrtl.vector, 16>, !firrtl.uint<8> + firrtl.connect %result, %1 :!firrtl.uint<8>, !firrtl.uint<8> } // CHECK-LABEL: firrtl.module @issue326 @@ -1553,6 +1690,20 @@ firrtl.module @add_cst_prop4(out %out_b: !firrtl.uint<5>) { firrtl.connect %out_b, %add2 : !firrtl.uint<5>, !firrtl.uint<5> } +// CHECK-LABEL: @add_cst_prop5 +// CHECK: %[[pad:.+]] = firrtl.pad %tmp_a, 5 +// CHECK-NEXT: firrtl.connect %out_b, %[[pad]] +// CHECK-NEXT: %[[pad:.+]] = firrtl.pad %tmp_a, 5 +// CHECK_NEXT: firrtl.connect %out_b, %[[pad]] +firrtl.module @add_cst_prop5(out %out_b: !firrtl.uint<5>) { + %tmp_a = firrtl.wire : !firrtl.uint<4> + %c0_ui4 = firrtl.constant 0 : !firrtl.uint<4> + %add = firrtl.add %tmp_a, %c0_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<5> + firrtl.connect %out_b, %add : !firrtl.uint<5>, !firrtl.uint<5> + %add2 = firrtl.add %c0_ui4, %tmp_a : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<5> + firrtl.connect %out_b, %add2 : !firrtl.uint<5>, !firrtl.uint<5> +} + // CHECK-LABEL: @sub_cst_prop1 // CHECK-NEXT: %c1_ui9 = firrtl.constant 1 : !firrtl.uint<9> // CHECK-NEXT: firrtl.connect %out_b, %c1_ui9 : !firrtl.uint<9>, !firrtl.uint<9> @@ -1582,11 +1733,16 @@ firrtl.module @sub_cst_prop2(out %out_b: !firrtl.sint<9>) { // CHECK-LABEL: @sub_cst_prop3 // CHECK: %[[pad:.+]] = firrtl.pad %tmp_a, 5 // CHECK-NEXT: firrtl.connect %out_b, %[[pad]] +// CHECK: %[[neg:.+]] = firrtl.neg %tmp_a +// CHECK: %[[cast:.+]] = firrtl.asUInt %[[neg]] +// CHECK-NEXT: firrtl.connect %out_b, %[[cast]] firrtl.module @sub_cst_prop3(out %out_b: !firrtl.uint<5>) { %tmp_a = firrtl.wire : !firrtl.uint<4> %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> %sub = firrtl.sub %tmp_a, %invalid_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<5> firrtl.connect %out_b, %sub : !firrtl.uint<5>, !firrtl.uint<5> + %sub2 = firrtl.sub %invalid_ui4, %tmp_a : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<5> + firrtl.connect %out_b, %sub2 : !firrtl.uint<5>, !firrtl.uint<5> } // CHECK-LABEL: @mul_cst_prop1 @@ -1909,6 +2065,17 @@ firrtl.module @dshifts_to_ishifts(in %a_in: !firrtl.sint<58>, %c438_ui10 = firrtl.constant 438 : !firrtl.uint<10> %2 = firrtl.dshr %c_in, %c438_ui10 : (!firrtl.sint<58>, !firrtl.uint<10>) -> !firrtl.sint<58> firrtl.connect %c_out, %2 : !firrtl.sint<58>, !firrtl.sint<58> + + // CHECK: firrtl.connect %a_out, %a_in : !firrtl.sint<58>, !firrtl.sint<58> + %invalid_ui10 = firrtl.invalidvalue : !firrtl.uint<10> + %3 = firrtl.dshr %a_in, %invalid_ui10 : (!firrtl.sint<58>, !firrtl.uint<10>) -> !firrtl.sint<58> + firrtl.connect %a_out, %3 : !firrtl.sint<58>, !firrtl.sint<58> + + // CHECK: [[TMP:%.+]] = firrtl.pad %b_in, 23 : (!firrtl.uint<8>) -> !firrtl.uint<23> + // CHECK: firrtl.connect %b_out, [[TMP]] : !firrtl.uint<23>, !firrtl.uint<23> + %invalid_ui4 = firrtl.invalidvalue : !firrtl.uint<4> + %4 = firrtl.dshl %b_in, %invalid_ui4 : (!firrtl.uint<8>, !firrtl.uint<4>) -> !firrtl.uint<23> + firrtl.connect %b_out, %4 : !firrtl.uint<23>, !firrtl.uint<23> } // CHECK-LABEL: firrtl.module @constReg @@ -2081,4 +2248,15 @@ firrtl.module @ZeroWidthCat(out %a: !firrtl.uint<1>) { // CHECK-NEXT: firrtl.connect %a, %[[one]] } +// Issue mentioned in PR #2251 +// CHECK-LABEL: @Issue2251 +firrtl.module @Issue2251(out %o: !firrtl.sint<15>) { + // pad used to always return an unsigned constant + %invalid_si1 = firrtl.invalidvalue : !firrtl.sint<1> + %0 = firrtl.pad %invalid_si1, 15 : (!firrtl.sint<1>) -> !firrtl.sint<15> + firrtl.connect %o, %0 : !firrtl.sint<15>, !firrtl.sint<15> + // CHECK: %[[zero:.+]] = firrtl.constant 0 : !firrtl.sint<15> + // CHECK-NEXT: firrtl.connect %o, %[[zero]] +} + }