[FIRRTL] Change BundleType::BundleElement to a struct instead of pair

This makes the client side code a lot more clear, NFC.
This commit is contained in:
Chris Lattner 2020-11-30 12:46:16 -08:00
parent dbcdd62a2e
commit 83dfe1cc6c
6 changed files with 62 additions and 48 deletions

View File

@ -1,6 +1,6 @@
//===- FIRRTL/IR/Ops.h - FIRRTL dialect -------------------------*- C++ -*-===// //===- FIRRTL/IR/Types.h - FIRRTL Type System -------------------*- C++ -*-===//
// //
// This file defines an MLIR dialect for the FIRRTL IR. // This file defines type type system for the FIRRTL Dialect.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -218,7 +218,17 @@ public:
using Base::Base; using Base::Base;
// Each element of a bundle, which is a name and type. // Each element of a bundle, which is a name and type.
using BundleElement = std::pair<Identifier, FIRRTLType>; struct BundleElement {
Identifier name;
FIRRTLType type;
BundleElement(Identifier name, FIRRTLType type) : name(name), type(type) {}
bool operator==(const BundleElement &rhs) const {
return name == rhs.name && type == rhs.type;
}
bool operator!=(const BundleElement &rhs) const { return !operator==(rhs); }
};
static FIRRTLType get(ArrayRef<BundleElement> elements, MLIRContext *context); static FIRRTLType get(ArrayRef<BundleElement> elements, MLIRContext *context);

View File

@ -32,7 +32,7 @@ using ValueVectorList = std::vector<ValueVector>;
/// type of the data subfield. /// type of the data subfield.
static FIRRTLType buildBundleType(FIRRTLType dataType, bool isFlip, static FIRRTLType buildBundleType(FIRRTLType dataType, bool isFlip,
MLIRContext *context) { MLIRContext *context) {
using BundleElement = std::pair<Identifier, FIRRTLType>; using BundleElement = BundleType::BundleElement;
llvm::SmallVector<BundleElement, 3> elements; llvm::SmallVector<BundleElement, 3> elements;
// Add valid and ready subfield to the bundle. // Add valid and ready subfield to the bundle.
@ -390,8 +390,8 @@ static ValueVectorList extractSubfields(FModuleOp subModuleOp,
if (auto argType = arg.getType().dyn_cast<BundleType>()) { if (auto argType = arg.getType().dyn_cast<BundleType>()) {
// Extract all subfields of all bundle ports. // Extract all subfields of all bundle ports.
for (auto &element : argType.getElements()) { for (auto &element : argType.getElements()) {
StringRef elementName = element.first.strref(); StringRef elementName = element.name.strref();
FIRRTLType elementType = element.second; FIRRTLType elementType = element.type;
subfields.push_back(rewriter.create<SubfieldOp>( subfields.push_back(rewriter.create<SubfieldOp>(
insertLoc, elementType, arg, rewriter.getStringAttr(elementName))); insertLoc, elementType, arg, rewriter.getStringAttr(elementName)));
} }
@ -1092,7 +1092,7 @@ static void createInstOp(Operation *oldOp, FModuleOp subModuleOp,
FModuleOp topModuleOp, unsigned clockDomain, FModuleOp topModuleOp, unsigned clockDomain,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
rewriter.setInsertionPointAfter(oldOp); rewriter.setInsertionPointAfter(oldOp);
using BundleElement = std::pair<Identifier, FIRRTLType>; using BundleElement = BundleType::BundleElement;
llvm::SmallVector<BundleElement, 8> elements; llvm::SmallVector<BundleElement, 8> elements;
MLIRContext *context = subModuleOp.getContext(); MLIRContext *context = subModuleOp.getContext();
@ -1118,11 +1118,9 @@ static void createInstOp(Operation *oldOp, FModuleOp subModuleOp,
// the top-module. // the top-module.
unsigned portIndex = 0; unsigned portIndex = 0;
for (auto &element : instType.cast<BundleType>().getElements()) { for (auto &element : instType.cast<BundleType>().getElements()) {
Identifier elementName = element.first;
FIRRTLType elementType = element.second;
auto subfieldOp = rewriter.create<SubfieldOp>( auto subfieldOp = rewriter.create<SubfieldOp>(
oldOp->getLoc(), elementType, instanceOp, oldOp->getLoc(), element.type, instanceOp,
rewriter.getStringAttr(elementName.strref())); rewriter.getStringAttr(element.name.strref()));
unsigned numIns = oldOp->getNumOperands(); unsigned numIns = oldOp->getNumOperands();
unsigned numArgs = numIns + oldOp->getNumResults(); unsigned numArgs = numIns + oldOp->getNumResults();

View File

@ -52,9 +52,9 @@ static void flattenBundleTypes(FIRRTLType type, StringRef suffixSoFar,
// Construct the suffix to pass down. // Construct the suffix to pass down.
tmpSuffix.resize(suffixSoFar.size()); tmpSuffix.resize(suffixSoFar.size());
tmpSuffix.push_back('_'); tmpSuffix.push_back('_');
tmpSuffix.append(elt.first.strref()); tmpSuffix.append(elt.name.strref());
// Recursively process subelements. // Recursively process subelements.
flattenBundleTypes(elt.second, tmpSuffix, isFlipped, results); flattenBundleTypes(elt.type, tmpSuffix, isFlipped, results);
} }
} }
@ -209,13 +209,13 @@ void FIRRTLTypesLowering::visitDecl(InstanceOp op) {
for (auto element : originalBundleType.getElements()) { for (auto element : originalBundleType.getElements()) {
// Flatten any nested bundle types the usual way. // Flatten any nested bundle types the usual way.
SmallVector<FlatBundleFieldEntry, 8> fieldTypes; SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
flattenBundleTypes(element.second, element.first, isFlip, fieldTypes); flattenBundleTypes(element.type, element.name, isFlip, fieldTypes);
for (auto field : fieldTypes) { for (auto field : fieldTypes) {
// Store the flat type for the new bundle type. // Store the flat type for the new bundle type.
auto flatName = builder->getIdentifier(field.suffix); auto flatName = builder->getIdentifier(field.suffix);
auto flatType = field.getPortType(); auto flatType = field.getPortType();
auto newElement = BundleType::BundleElement(flatName, flatType); auto newElement = BundleType::BundleElement{flatName, flatType};
bundleElements.push_back(newElement); bundleElements.push_back(newElement);
} }
} }
@ -229,10 +229,10 @@ void FIRRTLTypesLowering::visitDecl(InstanceOp op) {
// Create new subfield ops for each field of the instance. // Create new subfield ops for each field of the instance.
for (auto element : bundleElements) { for (auto element : bundleElements) {
auto newSubfield = auto newSubfield =
builder->create<SubfieldOp>(element.second, newInstance, element.first); builder->create<SubfieldOp>(element.type, newInstance, element.name);
// Map the flattened suffix for the original bundle to the new value. // Map the flattened suffix for the original bundle to the new value.
setBundleLowering(op, element.first, newSubfield); setBundleLowering(op, element.name, newSubfield);
} }
// Remember to remove the original op. // Remember to remove the original op.
@ -396,5 +396,5 @@ void FIRRTLTypesLowering::getAllBundleLowerings(
BundleType bundleType = getCanonicalBundleType(value.getType()); BundleType bundleType = getCanonicalBundleType(value.getType());
assert(bundleType && "attempted to get bundle lowerings for non-bundle type"); assert(bundleType && "attempted to get bundle lowerings for non-bundle type");
for (auto element : bundleType.getElements()) for (auto element : bundleType.getElements())
results.push_back(getBundleLowering(value, element.first)); results.push_back(getBundleLowering(value, element.name));
} }

View File

@ -553,12 +553,12 @@ static LogicalResult verifyInstanceOp(InstanceOp &instance) {
.getType() .getType()
.cast<FIRRTLType>() .cast<FIRRTLType>()
.getPassiveType(); .getPassiveType();
if (bundleElements[i].second != expectedType) { if (bundleElements[i].type != expectedType) {
auto diag = instance.emitOpError() auto diag = instance.emitOpError()
<< "output bundle type must match module. In " << "output bundle type must match module. In "
"element " "element "
<< i << ", expected " << expectedType << ", but got " << i << ", expected " << expectedType << ", but got "
<< bundleElements[i].second << "."; << bundleElements[i].type << ".";
diag.attachNote(referencedFModule.getLoc()) diag.attachNote(referencedFModule.getLoc())
<< "original module declared here"; << "original module declared here";
@ -666,9 +666,9 @@ void MemOp::getPorts(
// Each entry in the bundle is a port. // Each entry in the bundle is a port.
for (auto elt : bundle.getElements()) { for (auto elt : bundle.getElements()) {
// Each port is a bundle. // Each port is a bundle.
auto kind = getMemPortKindFromType(elt.second); auto kind = getMemPortKindFromType(elt.type);
assert(kind.hasValue() && "unknown port type!"); assert(kind.hasValue() && "unknown port type!");
result.push_back({elt.first, kind.getValue()}); result.push_back({elt.name, kind.getValue()});
} }
} }
@ -692,7 +692,7 @@ FIRRTLType MemOp::getDataTypeOrNull() {
return {}; return {};
auto firstPort = bundle.getElements()[0]; auto firstPort = bundle.getElements()[0];
auto firstPortType = firstPort.second.getPassiveType().cast<BundleType>(); auto firstPortType = firstPort.type.getPassiveType().cast<BundleType>();
return firstPortType.getElementType("data"); return firstPortType.getElementType("data");
} }
@ -801,8 +801,8 @@ FIRRTLType SubfieldOp::getResultType(FIRRTLType inType, StringRef fieldName,
Location loc) { Location loc) {
if (auto bundleType = inType.dyn_cast<BundleType>()) { if (auto bundleType = inType.dyn_cast<BundleType>()) {
for (auto &elt : bundleType.getElements()) { for (auto &elt : bundleType.getElements()) {
if (elt.first == fieldName) if (elt.name == fieldName)
return elt.second; return elt.type;
} }
} }

View File

@ -47,8 +47,8 @@ void FIRRTLType::print(raw_ostream &os) const {
os << "bundle<"; os << "bundle<";
llvm::interleaveComma(bundleType.getElements(), os, llvm::interleaveComma(bundleType.getElements(), os,
[&](BundleType::BundleElement element) { [&](BundleType::BundleElement element) {
os << element.first << ": "; os << element.name << ": ";
element.second.print(os); element.type.print(os);
}); });
os << '>'; os << '>';
}) })
@ -233,7 +233,7 @@ FIRRTLType FIRRTLType::getMaskType() {
SmallVector<BundleType::BundleElement, 4> newElements; SmallVector<BundleType::BundleElement, 4> newElements;
newElements.reserve(bundleType.getElements().size()); newElements.reserve(bundleType.getElements().size());
for (auto elt : bundleType.getElements()) for (auto elt : bundleType.getElements())
newElements.push_back({elt.first, elt.second.getMaskType()}); newElements.push_back({elt.name, elt.type.getMaskType()});
return BundleType::get(newElements, getContext()); return BundleType::get(newElements, getContext());
}) })
.Case<FVectorType>([](FVectorType vectorType) { .Case<FVectorType>([](FVectorType vectorType) {
@ -260,7 +260,7 @@ FIRRTLType FIRRTLType::getWidthlessType() {
SmallVector<BundleType::BundleElement, 4> newElements; SmallVector<BundleType::BundleElement, 4> newElements;
newElements.reserve(a.getElements().size()); newElements.reserve(a.getElements().size());
for (auto elt : a.getElements()) for (auto elt : a.getElements())
newElements.push_back({elt.first, elt.second.getWidthlessType()}); newElements.push_back({elt.name, elt.type.getWidthlessType()});
return BundleType::get(newElements, getContext()); return BundleType::get(newElements, getContext());
}) })
.Case<FVectorType>([](auto a) { .Case<FVectorType>([](auto a) {
@ -308,14 +308,10 @@ bool FIRRTLType::isResetType() {
/// canonicalizes flips in bundles, so only passive types can be compared here. /// canonicalizes flips in bundles, so only passive types can be compared here.
static bool areBundleElementsEquivalent(BundleType::BundleElement destElement, static bool areBundleElementsEquivalent(BundleType::BundleElement destElement,
BundleType::BundleElement srcElement) { BundleType::BundleElement srcElement) {
Identifier destElementName = std::get<0>(destElement); if (destElement.name != srcElement.name)
Identifier srcElementName = std::get<0>(srcElement);
if (destElementName != srcElementName)
return false; return false;
FIRRTLType destElementType = std::get<1>(destElement); return areTypesEquivalent(destElement.type, srcElement.type);
FIRRTLType srcElementType = std::get<1>(srcElement);
return areTypesEquivalent(destElementType, srcElementType);
} }
/// Returns whether the two types are equivalent. See the FIRRTL spec for the /// Returns whether the two types are equivalent. See the FIRRTL spec for the
@ -481,8 +477,8 @@ getFlippedBundleType(ArrayRef<BundleType::BundleElement> elements) {
SmallVector<BundleType::BundleElement, 16> flippedelements; SmallVector<BundleType::BundleElement, 16> flippedelements;
flippedelements.reserve(elements.size()); flippedelements.reserve(elements.size());
for (auto &elt : elements) for (auto &elt : elements)
flippedelements.push_back({elt.first, FlipType::get(elt.second)}); flippedelements.push_back({elt.name, FlipType::get(elt.type)});
return BundleType::get(flippedelements, elements[0].second.getContext()); return BundleType::get(flippedelements, elements[0].type.getContext());
} }
FIRRTLType FlipType::get(FIRRTLType element) { FIRRTLType FlipType::get(FIRRTLType element) {
@ -538,6 +534,14 @@ FIRRTLType FlipType::getElementType() { return getImpl()->element; }
// Bundle Type // Bundle Type
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace circt {
namespace firrtl {
llvm::hash_code hash_value(const BundleType::BundleElement &arg) {
return llvm::hash_value(arg.name) ^ mlir::hash_value(arg.type);
}
} // namespace firrtl
} // namespace circt
namespace circt { namespace circt {
namespace firrtl { namespace firrtl {
namespace detail { namespace detail {
@ -546,17 +550,19 @@ struct BundleTypeStorage : mlir::TypeStorage {
BundleTypeStorage(KeyTy elements) BundleTypeStorage(KeyTy elements)
: elements(elements.begin(), elements.end()) { : elements(elements.begin(), elements.end()) {
bool isPassive =
bool isPassive = llvm::all_of( llvm::all_of(elements, [](BundleType::BundleElement elt) -> bool {
elements, [](const BundleType::BundleElement &elt) -> bool { return elt.type.isPassive();
auto eltType = elt.second;
return eltType.isPassive();
}); });
passiveTypeInfo.setInt(isPassive); passiveTypeInfo.setInt(isPassive);
} }
bool operator==(const KeyTy &key) const { return key == KeyTy(elements); } bool operator==(const KeyTy &key) const { return key == KeyTy(elements); }
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine_range(key.begin(), key.end());
}
static BundleTypeStorage *construct(TypeStorageAllocator &allocator, static BundleTypeStorage *construct(TypeStorageAllocator &allocator,
KeyTy key) { KeyTy key) {
return new (allocator.allocate<BundleTypeStorage>()) BundleTypeStorage(key); return new (allocator.allocate<BundleTypeStorage>()) BundleTypeStorage(key);
@ -579,7 +585,7 @@ FIRRTLType BundleType::get(ArrayRef<BundleElement> elements,
// the outer level. // the outer level.
if (!elements.empty() && if (!elements.empty() &&
llvm::all_of(elements, [&](const BundleElement &elt) -> bool { llvm::all_of(elements, [&](const BundleElement &elt) -> bool {
return elt.second.isa<FlipType>(); return elt.type.isa<FlipType>();
})) { })) {
return FlipType::get(getFlippedBundleType(elements)); return FlipType::get(getFlippedBundleType(elements));
} }
@ -608,7 +614,7 @@ FIRRTLType BundleType::getPassiveType() {
SmallVector<BundleType::BundleElement, 16> newElements; SmallVector<BundleType::BundleElement, 16> newElements;
newElements.reserve(impl->elements.size()); newElements.reserve(impl->elements.size());
for (auto &elt : impl->elements) { for (auto &elt : impl->elements) {
newElements.push_back({elt.first, elt.second.getPassiveType()}); newElements.push_back({elt.name, elt.type.getPassiveType()});
} }
auto passiveType = BundleType::get(newElements, getContext()); auto passiveType = BundleType::get(newElements, getContext());
@ -619,7 +625,7 @@ FIRRTLType BundleType::getPassiveType() {
/// Look up an element by name. This returns a BundleElement with. /// Look up an element by name. This returns a BundleElement with.
auto BundleType::getElement(StringRef name) -> Optional<BundleElement> { auto BundleType::getElement(StringRef name) -> Optional<BundleElement> {
for (const auto &element : getElements()) { for (const auto &element : getElements()) {
if (element.first == name) if (element.name == name)
return element; return element;
} }
return None; return None;
@ -627,7 +633,7 @@ auto BundleType::getElement(StringRef name) -> Optional<BundleElement> {
FIRRTLType BundleType::getElementType(StringRef name) { FIRRTLType BundleType::getElementType(StringRef name) {
auto element = getElement(name); auto element = getElement(name);
return element.hasValue() ? element.getValue().second : FIRRTLType(); return element.hasValue() ? element.getValue().type : FIRRTLType();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -145,9 +145,9 @@ static void flattenBundleTypes(Type type, StringRef suffixSoFar, bool isFlipped,
// Construct the suffix to pass down. // Construct the suffix to pass down.
tmpSuffix.resize(suffixSoFar.size()); tmpSuffix.resize(suffixSoFar.size());
tmpSuffix.push_back('_'); tmpSuffix.push_back('_');
tmpSuffix.append(elt.first.strref()); tmpSuffix.append(elt.name.strref());
// Recursively process subelements. // Recursively process subelements.
flattenBundleTypes(elt.second, tmpSuffix, isFlipped, results); flattenBundleTypes(elt.type, tmpSuffix, isFlipped, results);
} }
} }