[NFC] Rename confusing api.

.get() on a type should construct a new type.  BaseAlias was using this function to ALSO return the wrapped type.  Rename the accessor to not conflict with the factory.
This commit is contained in:
Andrew Lenharth 2024-01-24 16:13:16 -06:00
parent 6ef649eb43
commit 76cda20533
11 changed files with 43 additions and 41 deletions

View File

@ -233,7 +233,7 @@ def FEnumCreateOp : FIRRTLOp<"enumcreate"> {
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// Return the name attribute of the accessed field. /// Return the name attribute of the accessed field.
StringAttr getFieldNameAttr() { StringAttr getFieldNameAttr() {
return getResult().getType().get().getElementNameAttr(getFieldIndex()); return getResult().getType().base().getElementNameAttr(getFieldIndex());
} }
/// Return the name of the accessed field. /// Return the name of the accessed field.
@ -337,7 +337,7 @@ def SubindexOp : FIRRTLExprOp<"subindex"> {
let firrtlExtraClassDeclaration = [{ let firrtlExtraClassDeclaration = [{
/// Return a `FieldRef` to the accessed field. /// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() { FieldRef getAccessedField() {
return FieldRef(getInput(), getInput().getType().get().getFieldID(getIndex())); return FieldRef(getInput(), getInput().getType().base().getFieldID(getIndex()));
} }
using InputType = FVectorType; using InputType = FVectorType;
}]; }];
@ -464,13 +464,13 @@ def SubtagOp : FIRRTLExprOp<"subtag"> {
let firrtlExtraClassDeclaration = [{ let firrtlExtraClassDeclaration = [{
/// Return a `FieldRef` to the accessed field. /// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() { FieldRef getAccessedField() {
return FieldRef(getInput(), getInput().getType().get() return FieldRef(getInput(), getInput().getType().base()
.getFieldID(getFieldIndex())); .getFieldID(getFieldIndex()));
} }
/// Return the name of the accessed field. /// Return the name of the accessed field.
StringAttr getFieldNameAttr() { StringAttr getFieldNameAttr() {
return getInput().getType().get().getElementNameAttr(getFieldIndex()); return getInput().getType().base().getElementNameAttr(getFieldIndex());
} }
/// Return the name of the accessed field. /// Return the name of the accessed field.

View File

@ -612,7 +612,7 @@ public:
// Support C++ implicit conversions to BaseTy. // Support C++ implicit conversions to BaseTy.
operator BaseTy() const { return circt::firrtl::type_cast<BaseTy>(*this); } operator BaseTy() const { return circt::firrtl::type_cast<BaseTy>(*this); }
BaseTy get() const { return circt::firrtl::type_cast<BaseTy>(*this); } BaseTy base() const { return circt::firrtl::type_cast<BaseTy>(*this); }
}; };
} // namespace firrtl } // namespace firrtl

View File

@ -1210,7 +1210,7 @@ static SmallVector<SubfieldOp> getAllFieldAccesses(Value structValue,
assert(isa<SubfieldOp>(op)); assert(isa<SubfieldOp>(op));
auto fieldAccess = cast<SubfieldOp>(op); auto fieldAccess = cast<SubfieldOp>(op);
auto elemIndex = auto elemIndex =
fieldAccess.getInput().getType().get().getElementIndex(field); fieldAccess.getInput().getType().base().getElementIndex(field);
if (elemIndex && *elemIndex == fieldAccess.getFieldIndex()) if (elemIndex && *elemIndex == fieldAccess.getFieldIndex())
accesses.push_back(fieldAccess); accesses.push_back(fieldAccess);
} }

View File

@ -443,7 +443,7 @@ OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
/// supersede any division with invalid or zero. Division of invalid by /// supersede any division with invalid or zero. Division of invalid by
/// invalid should be one. /// invalid should be one.
if (getLhs() == getRhs()) { if (getLhs() == getRhs()) {
auto width = getType().get().getWidthOrSentinel(); auto width = getType().base().getWidthOrSentinel();
if (width == -1) if (width == -1)
width = 2; width = 2;
// Only fold if we have at least 1 bit of width to represent the `1` value. // Only fold if we have at least 1 bit of width to represent the `1` value.
@ -522,8 +522,8 @@ OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
return constFoldFIRRTLBinaryOp( return constFoldFIRRTLBinaryOp(
*this, adaptor.getOperands(), BinOpKind::DivideOrShift, *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
[=](const APSInt &a, const APSInt &b) -> APInt { [=](const APSInt &a, const APSInt &b) -> APInt {
return getType().get().isUnsigned() || !a.getBitWidth() ? a.lshr(b) return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
: a.ashr(b); : a.ashr(b);
}); });
} }
@ -623,7 +623,8 @@ OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
/// xor(x, x) -> 0 /// xor(x, x) -> 0
if (getLhs() == getRhs()) if (getLhs() == getRhs())
return getIntAttr( return getIntAttr(
getType(), APInt(std::max(getType().get().getWidthOrSentinel(), 0), 0)); getType(),
APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
return constFoldFIRRTLBinaryOp( return constFoldFIRRTLBinaryOp(
*this, adaptor.getOperands(), BinOpKind::Normal, *this, adaptor.getOperands(), BinOpKind::Normal,
@ -643,14 +644,14 @@ void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
} }
OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) { OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
bool isUnsigned = getLhs().getType().get().isUnsigned(); bool isUnsigned = getLhs().getType().base().isUnsigned();
// leq(x, x) -> 1 // leq(x, x) -> 1
if (getLhs() == getRhs()) if (getLhs() == getRhs())
return getIntAttr(getType(), APInt(1, 1)); return getIntAttr(getType(), APInt(1, 1));
// Comparison against constant outside type bounds. // Comparison against constant outside type bounds.
if (auto width = getLhs().getType().get().getWidth()) { if (auto width = getLhs().getType().base().getWidth()) {
if (auto rhsCst = getConstant(adaptor.getRhs())) { if (auto rhsCst = getConstant(adaptor.getRhs())) {
auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth()); auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
commonWidth = std::max(commonWidth, 1); commonWidth = std::max(commonWidth, 1);
@ -961,7 +962,7 @@ OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
// Be careful to only fold the cast into the constant if the size is known. // Be careful to only fold the cast into the constant if the size is known.
// Otherwise width inference may produce differently-sized constants if the // Otherwise width inference may produce differently-sized constants if the
// sign changes. // sign changes.
if (getType().get().hasWidth()) if (getType().base().hasWidth())
if (auto cst = getConstant(adaptor.getInput())) if (auto cst = getConstant(adaptor.getInput()))
return getIntAttr(getType(), *cst); return getIntAttr(getType(), *cst);
@ -981,7 +982,7 @@ OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
// Be careful to only fold the cast into the constant if the size is known. // Be careful to only fold the cast into the constant if the size is known.
// Otherwise width inference may produce differently-sized constants if the // Otherwise width inference may produce differently-sized constants if the
// sign changes. // sign changes.
if (getType().get().hasWidth()) if (getType().base().hasWidth())
if (auto cst = getConstant(adaptor.getInput())) if (auto cst = getConstant(adaptor.getInput()))
return getIntAttr(getType(), *cst); return getIntAttr(getType(), *cst);
@ -1023,7 +1024,7 @@ OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
// Signed to signed is a noop, unsigned operands prepend a zero bit. // Signed to signed is a noop, unsigned operands prepend a zero bit.
if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(), if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
getType().get().getWidthOrSentinel())) getType().base().getWidthOrSentinel()))
return getIntAttr(getType(), *cst); return getIntAttr(getType(), *cst);
return {}; return {};
@ -1041,7 +1042,7 @@ OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
// FIRRTL negate always adds a bit. // FIRRTL negate always adds a bit.
// -x ---> 0-sext(x) or 0-zext(x) // -x ---> 0-sext(x) or 0-zext(x)
if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(), if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
getType().get().getWidthOrSentinel())) getType().base().getWidthOrSentinel()))
return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst); return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
return {}; return {};
@ -1052,7 +1053,7 @@ OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
return {}; return {};
if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(), if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
getType().get().getWidthOrSentinel())) getType().base().getWidthOrSentinel()))
return getIntAttr(getType(), ~*cst); return getIntAttr(getType(), ~*cst);
return {}; return {};
@ -1471,14 +1472,14 @@ OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
return input; return input;
// Need to know the input width. // Need to know the input width.
auto inputType = input.getType().get(); auto inputType = input.getType().base();
int32_t width = inputType.getWidthOrSentinel(); int32_t width = inputType.getWidthOrSentinel();
if (width == -1) if (width == -1)
return {}; return {};
// Constant fold. // Constant fold.
if (auto cst = getConstant(adaptor.getInput())) { if (auto cst = getConstant(adaptor.getInput())) {
auto destWidth = getType().get().getWidthOrSentinel(); auto destWidth = getType().base().getWidthOrSentinel();
if (destWidth == -1) if (destWidth == -1)
return {}; return {};
@ -1545,7 +1546,7 @@ OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
} }
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) { LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
auto inputWidth = op.getInput().getType().get().getWidthOrSentinel(); auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
if (inputWidth <= 0) if (inputWidth <= 0)
return failure(); return failure();
@ -1553,7 +1554,7 @@ LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
unsigned shiftAmount = op.getAmount(); unsigned shiftAmount = op.getAmount();
if (int(shiftAmount) >= inputWidth) { if (int(shiftAmount) >= inputWidth) {
// shift(x, 32) => 0 when x has 32 bits. This is handled by fold(). // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
if (op.getType().get().isUnsigned()) if (op.getType().base().isUnsigned())
return failure(); return failure();
// Shifting a signed value by the full width is actually taking the // Shifting a signed value by the full width is actually taking the
@ -1568,7 +1569,7 @@ LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op, LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto inputWidth = op.getInput().getType().get().getWidthOrSentinel(); auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
if (inputWidth <= 0) if (inputWidth <= 0)
return failure(); return failure();
@ -1584,7 +1585,7 @@ OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
if (hasKnownWidthIntTypes(*this)) if (hasKnownWidthIntTypes(*this))
if (auto cst = getConstant(adaptor.getInput())) { if (auto cst = getConstant(adaptor.getInput())) {
int shiftAmount = int shiftAmount =
getInput().getType().get().getWidthOrSentinel() - getAmount(); getInput().getType().base().getWidthOrSentinel() - getAmount();
return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount())); return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
} }
@ -1595,13 +1596,13 @@ OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
if (hasKnownWidthIntTypes(*this)) if (hasKnownWidthIntTypes(*this))
if (auto cst = getConstant(adaptor.getInput())) if (auto cst = getConstant(adaptor.getInput()))
return getIntAttr(getType(), return getIntAttr(getType(),
cst->trunc(getType().get().getWidthOrSentinel())); cst->trunc(getType().base().getWidthOrSentinel()));
return {}; return {};
} }
LogicalResult TailPrimOp::canonicalize(TailPrimOp op, LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto inputWidth = op.getInput().getType().get().getWidthOrSentinel(); auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
if (inputWidth <= 0) if (inputWidth <= 0)
return failure(); return failure();

View File

@ -3209,7 +3209,7 @@ static LogicalResult checkConnectConditionality(FConnectLike connect) {
.Case<SubaccessOp>([&](SubaccessOp op) { .Case<SubaccessOp>([&](SubaccessOp op) {
if (op.getInput() if (op.getInput()
.getType() .getType()
.get() .base()
.getElementTypePreservingConst() .getElementTypePreservingConst()
.isConst()) .isConst())
originalFieldType = originalFieldType.getConstType(true); originalFieldType = originalFieldType.getConstType(true);
@ -4025,7 +4025,7 @@ ParseResult FEnumCreateOp::parse(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult IsTagOp::verify() { LogicalResult IsTagOp::verify() {
if (getFieldIndex() >= getInput().getType().get().getNumElements()) if (getFieldIndex() >= getInput().getType().base().getNumElements())
return emitOpError("element index is greater than the number of fields in " return emitOpError("element index is greater than the number of fields in "
"the bundle type"); "the bundle type");
return success(); return success();
@ -4214,7 +4214,7 @@ LogicalResult OpenSubfieldOp::verify() {
} }
LogicalResult SubtagOp::verify() { LogicalResult SubtagOp::verify() {
if (getFieldIndex() >= getInput().getType().get().getNumElements()) if (getFieldIndex() >= getInput().getType().base().getNumElements())
return emitOpError("subfield element index is greater than the number " return emitOpError("subfield element index is greater than the number "
"of fields in the bundle type"); "of fields in the bundle type");
return success(); return success();

View File

@ -118,17 +118,17 @@ public:
.Case<SubindexOp>([&](SubindexOp sub) { .Case<SubindexOp>([&](SubindexOp sub) {
recordValueRefersToFieldRef( recordValueRefersToFieldRef(
sub.getInput(), sub.getInput(),
sub.getInput().getType().get().getFieldID(sub.getIndex()), sub.getInput().getType().base().getFieldID(sub.getIndex()),
sub.getResult()); sub.getResult());
}) })
.Case<SubfieldOp>([&](SubfieldOp sub) { .Case<SubfieldOp>([&](SubfieldOp sub) {
recordValueRefersToFieldRef( recordValueRefersToFieldRef(
sub.getInput(), sub.getInput(),
sub.getInput().getType().get().getFieldID(sub.getFieldIndex()), sub.getInput().getType().base().getFieldID(sub.getFieldIndex()),
sub.getResult()); sub.getResult());
}) })
.Case<SubaccessOp>([&](SubaccessOp sub) { .Case<SubaccessOp>([&](SubaccessOp sub) {
auto vecType = sub.getInput().getType().get(); auto vecType = sub.getInput().getType().base();
auto input = sub.getInput(); auto input = sub.getInput();
auto result = sub.getResult(); auto result = sub.getResult();
// Flatten the subaccess. The result can refer to any of the // Flatten the subaccess. The result can refer to any of the

View File

@ -80,7 +80,7 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
for (Operation *u : portVal.getUsers()) for (Operation *u : portVal.getUsers())
if (auto sf = dyn_cast<SubfieldOp>(u)) { if (auto sf = dyn_cast<SubfieldOp>(u)) {
// Get the field name. // Get the field name.
auto fName = sf.getInput().getType().get().getElementName( auto fName = sf.getInput().getType().base().getElementName(
sf.getFieldIndex()); sf.getFieldIndex());
// If this is the enable field, record the product terms(the And // If this is the enable field, record the product terms(the And
// expression tree). // expression tree).
@ -188,7 +188,7 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
// replace them. // replace them.
for (Operation *u : portVal.getUsers()) for (Operation *u : portVal.getUsers())
if (auto sf = dyn_cast<SubfieldOp>(u)) { if (auto sf = dyn_cast<SubfieldOp>(u)) {
StringRef fName = sf.getInput().getType().get().getElementName( StringRef fName = sf.getInput().getType().base().getElementName(
sf.getFieldIndex()); sf.getFieldIndex());
Value repl; Value repl;
if (isReadPort) if (isReadPort)
@ -328,7 +328,7 @@ private:
if (auto sf = dyn_cast<SubfieldOp>(u)) { if (auto sf = dyn_cast<SubfieldOp>(u)) {
// Get the field name. // Get the field name.
auto fName = auto fName =
sf.getInput().getType().get().getElementName(sf.getFieldIndex()); sf.getInput().getType().base().getElementName(sf.getFieldIndex());
// Check if this is the mask field. // Check if this is the mask field.
if (fName.contains("mask")) { if (fName.contains("mask")) {
// Already 1 bit, nothing to do. // Already 1 bit, nothing to do.
@ -379,7 +379,7 @@ private:
auto sf = auto sf =
builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex()); builder.create<SubfieldOp>(newPortVal, oldRes.getFieldIndex());
auto fName = auto fName =
sf.getInput().getType().get().getElementName(sf.getFieldIndex()); sf.getInput().getType().base().getElementName(sf.getFieldIndex());
// Replace all mask fields with a one bit constant 1. // Replace all mask fields with a one bit constant 1.
// Replace all other fields with the new port. // Replace all other fields with the new port.
if (fName.contains("mask")) { if (fName.contains("mask")) {

View File

@ -1394,7 +1394,7 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
// If the constant has a known width, use that. Otherwise pick the // If the constant has a known width, use that. Otherwise pick the
// smallest number of bits necessary to represent the constant. // smallest number of bits necessary to represent the constant.
Expr *e; Expr *e;
if (auto width = op.getType().get().getWidth()) if (auto width = op.getType().base().getWidth())
e = solver.known(*width); e = solver.known(*width);
else { else {
auto v = op.getValue(); auto v = op.getValue();
@ -1496,7 +1496,7 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
.Case<DivPrimOp>([&](auto op) { .Case<DivPrimOp>([&](auto op) {
auto lhs = getExpr(op.getLhs()); auto lhs = getExpr(op.getLhs());
Expr *e; Expr *e;
if (op.getType().get().isSigned()) { if (op.getType().base().isSigned()) {
e = solver.add(lhs, solver.known(1)); e = solver.add(lhs, solver.known(1));
} else { } else {
e = lhs; e = lhs;
@ -1542,7 +1542,7 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
}) })
.Case<CvtPrimOp>([&](auto op) { .Case<CvtPrimOp>([&](auto op) {
auto input = getExpr(op.getInput()); auto input = getExpr(op.getInput());
auto e = op.getInput().getType().get().isSigned() auto e = op.getInput().getType().base().isSigned()
? input ? input
: solver.add(input, solver.known(1)); : solver.add(input, solver.known(1));
setExpr(op.getResult(), e); setExpr(op.getResult(), e);

View File

@ -325,7 +325,7 @@ static SmallVector<SubfieldOp> getAllFieldAccesses(Value structValue,
assert(isa<SubfieldOp>(op)); assert(isa<SubfieldOp>(op));
auto fieldAccess = cast<SubfieldOp>(op); auto fieldAccess = cast<SubfieldOp>(op);
auto elemIndex = auto elemIndex =
fieldAccess.getInput().getType().get().getElementIndex(field); fieldAccess.getInput().getType().base().getElementIndex(field);
if (elemIndex && *elemIndex == fieldAccess.getFieldIndex()) if (elemIndex && *elemIndex == fieldAccess.getFieldIndex())
accesses.push_back(fieldAccess); accesses.push_back(fieldAccess);
} }

View File

@ -234,7 +234,7 @@ static bool isNotSubAccess(Operation *op) {
return true; return true;
ConstantOp arg = ConstantOp arg =
llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp()); llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
return arg && sao.getInput().getType().get().getNumElements() != 0; return arg && sao.getInput().getType().base().getNumElements() != 0;
} }
/// Look through and collect subfields leading to a subaccess. /// Look through and collect subfields leading to a subaccess.

View File

@ -846,7 +846,8 @@ LogicalResult Visitor::visitExpr(VectorCreateOp op) {
} }
auto value = sinkVecDimIntoOperands( auto value = sinkVecDimIntoOperands(
builder, convertType(oldType.get().getElementType()), convertedOldFields); builder, convertType(oldType.base().getElementType()),
convertedOldFields);
valueMap[op.getResult()] = value; valueMap[op.getResult()] = value;
toDelete.push_back(op); toDelete.push_back(op);
return success(); return success();