[FIRRTL/LowerToHW] Support bundle type lowering (#2323)

* [FIRRTL/LowerToHW] Support missing bundle type lowering

This commit makes LowerToHW handle bundle values properly:
(i) bundle registers initialization and (ii) extension of
`getLoweredAndExtendValue` to accept struct/bundle values.
This commit is contained in:
Hideto Ueno 2021-12-12 15:02:40 +09:00 committed by GitHub
parent aeec3466fc
commit 2577910657
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 17 deletions

View File

@ -173,7 +173,8 @@ def StructExtractOp : HWOp<"struct_extract", [NoSideEffect]> {
let printer = "return ::print$cppClass(p, *this);";
let builders = [
OpBuilder<(ins "Value":$input, "StructType::FieldInfo":$field)>
OpBuilder<(ins "Value":$input, "StructType::FieldInfo":$field)>,
OpBuilder<(ins "Value":$input, "StringAttr":$field)>
];
}

View File

@ -1136,8 +1136,8 @@ struct FIRRTLLowering : public FIRRTLVisitor<FIRRTLLowering, LogicalResult> {
std::function<void(void)> elseCtor = {});
void addIfProceduralBlock(Value cond, std::function<void(void)> thenCtor,
std::function<void(void)> elseCtor = {});
Value getExtOrTruncArrayValue(Value array, FIRRTLType sourceType,
FIRRTLType destType, bool allowTruncate);
Value getExtOrTruncAggregateValue(Value array, FIRRTLType sourceType,
FIRRTLType destType, bool allowTruncate);
// Create a temporary wire at the current insertion point, and try to
// eliminate it later as part of lowering post processing.
@ -1484,13 +1484,12 @@ Value FIRRTLLowering::getLoweredValue(Value value) {
return result;
}
/// Return the lowered array value whose type is converted into `destType`.
/// Return the lowered aggregate value whose type is converted into `destType`.
/// We have to care about the extension/truncation/signedness of each element.
/// If returns a null value for complex arrays such as arrays with bundles.
Value FIRRTLLowering::getExtOrTruncArrayValue(Value array,
FIRRTLType sourceType,
FIRRTLType destType,
bool allowTruncate) {
Value FIRRTLLowering::getExtOrTruncAggregateValue(Value array,
FIRRTLType sourceType,
FIRRTLType destType,
bool allowTruncate) {
SmallVector<Value> resultBuffer;
// Helper function to cast each element of array to dest type.
@ -1541,6 +1540,30 @@ Value FIRRTLLowering::getExtOrTruncArrayValue(Value array,
resultBuffer.push_back(array);
return success();
})
.Case<BundleType>([&](BundleType srcStructType) {
auto destStructType = destType.cast<BundleType>();
unsigned size = resultBuffer.size();
// TODO: We don't support partial connects for bundles for now.
if (destStructType.getNumElements() != srcStructType.getNumElements())
return failure();
for (auto elem : enumerate(destStructType.getElements())) {
auto structExtract =
builder.create<hw::StructExtractOp>(src, elem.value().name);
if (failed(recurse(structExtract,
srcStructType.getElementType(elem.index()),
destStructType.getElementType(elem.index()))))
return failure();
}
SmallVector<Value> temp(resultBuffer.begin() + size,
resultBuffer.end());
auto newStruct = builder.createOrFold<hw::StructCreateOp>(
lowerType(destStructType), temp);
resultBuffer.resize(size);
resultBuffer.push_back(newStruct);
return success();
})
.Case<IntType>([&](auto) {
if (auto result = cast(src, srcType, destType)) {
resultBuffer.push_back(result);
@ -1588,14 +1611,15 @@ Value FIRRTLLowering::getLoweredAndExtendedValue(Value value, Type destType) {
return getOrCreateIntConstant(destWidth, 0);
}
if (auto array = result.getType().dyn_cast<hw::ArrayType>()) {
// Aggregates values
if (result.getType().isa<hw::ArrayType, hw::StructType>()) {
// Types already match.
if (destType == value.getType())
return result;
return getExtOrTruncArrayValue(result, value.getType().cast<FIRRTLType>(),
destType.cast<FIRRTLType>(),
/* allowTruncate */ false);
return getExtOrTruncAggregateValue(
result, value.getType().cast<FIRRTLType>(), destType.cast<FIRRTLType>(),
/* allowTruncate */ false);
}
auto srcWidth = result.getType().cast<IntegerType>().getWidth();
@ -1647,14 +1671,15 @@ Value FIRRTLLowering::getLoweredAndExtOrTruncValue(Value value, Type destType) {
return getOrCreateIntConstant(destWidth, 0);
}
if (auto array = result.getType().dyn_cast<hw::ArrayType>()) {
// Aggregates values
if (result.getType().isa<hw::ArrayType, hw::StructType>()) {
// Types already match.
if (destType == value.getType())
return result;
return getExtOrTruncArrayValue(result, value.getType().cast<FIRRTLType>(),
destType.cast<FIRRTLType>(),
/* allowTruncate */ true);
return getExtOrTruncAggregateValue(
result, value.getType().cast<FIRRTLType>(), destType.cast<FIRRTLType>(),
/* allowTruncate */ true);
}
auto srcWidth = result.getType().cast<IntegerType>().getWidth();
@ -2181,6 +2206,13 @@ void FIRRTLLowering::initializeRegister(Value reg) {
recurse(arrayIndex, a.getElementType());
}
})
.Case<hw::StructType>([&](hw::StructType s) {
for (auto elem : s.getElements()) {
auto field =
builder.create<sv::StructFieldInOutOp>(reg, elem.name);
recurse(field, elem.type);
}
})
.Default([&](auto type) { emitRandomInit(reg, type); });
};
recurse(reg, type);

View File

@ -1469,6 +1469,13 @@ void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
build(builder, odsState, field.type, input, field.name);
}
void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
Value input, StringAttr fieldAttr) {
auto structType = input.getType().cast<StructType>();
auto resultType = structType.getFieldType(fieldAttr);
build(builder, odsState, resultType, input, fieldAttr);
}
//===----------------------------------------------------------------------===//
// StructInjectOp
//===----------------------------------------------------------------------===//

View File

@ -1466,4 +1466,43 @@ firrtl.circuit "Simple" attributes {annotations = [{class =
// CHECK-NEXT: sv.assign %3, %in : i1
// CHECK-NEXT: hw.output %0 : !hw.struct<a: !hw.struct<b: !hw.struct<c: i1>>>
}
// CHECK-LABEL: hw.module @initStruct
firrtl.module @initStruct(in %clock: !firrtl.clock) {
// CHECK: sv.ifdef.procedural "RANDOMIZE_REG_INIT" {
// CHECK-NEXT: %0 = sv.struct_field_inout %r["a"] : !hw.inout<struct<a: i1>>
// CHECK-NEXT: %RANDOM = sv.verbatim.expr.se "`RANDOM" : () -> i32 {symbols = []}
// CHECK-NEXT: %1 = comb.extract %RANDOM from 0 : (i32) -> i1
// CHECK-NEXT: sv.bpassign %0, %1 : i1
// CHECK-NEXT: }
%r = firrtl.reg %clock : !firrtl.bundle<a: uint<1>>
}
// CHECK-LABEL: hw.module @RegResetStructNarrow
firrtl.module @RegResetStructNarrow(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %init: !firrtl.bundle<a: uint<2>>) {
// CHECK: %0 = hw.struct_extract %init["a"] : !hw.struct<a: i2>
// CHECK-NEXT: %1 = comb.extract %0 from 0 : (i2) -> i1
// CHECK-NEXT: %2 = hw.struct_create (%1) : !hw.struct<a: i1>
// CHECK-NEXT: %reg = sv.reg : !hw.inout<struct<a: i1>>
// CHECK-NEXT: sv.always posedge %clock {
// CHECK-NEXT: sv.if %reset {
// CHECK-NEXT: sv.passign %reg, %2 : !hw.struct<a: i1>
// CHECK-NEXT: } else {
// CHECK-NEXT: }
// CHECK-NEXT: }
%reg = firrtl.regreset %clock, %reset, %init : !firrtl.uint<1>, !firrtl.bundle<a: uint<2>>, !firrtl.bundle<a: uint<1>>
}
// CHECK-LABEL: hw.module @BundleConnection
firrtl.module @BundleConnection(in %source: !firrtl.bundle<a: bundle<b: uint<1>>>, out %sink: !firrtl.bundle<a: bundle<b: uint<1>>>) {
%0 = firrtl.subfield %sink(0) : (!firrtl.bundle<a: bundle<b: uint<1>>>) -> !firrtl.bundle<b: uint<1>>
%1 = firrtl.subfield %source(0) : (!firrtl.bundle<a: bundle<b: uint<1>>>) -> !firrtl.bundle<b: uint<1>>
firrtl.connect %0, %1 : !firrtl.bundle<b: uint<1>>, !firrtl.bundle<b: uint<1>>
// CHECK: %.sink.output = sv.wire : !hw.inout<struct<a: !hw.struct<b: i1>>>
// CHECK-NEXT: %0 = sv.read_inout %.sink.output : !hw.inout<struct<a: !hw.struct<b: i1>>>
// CHECK-NEXT: %1 = sv.struct_field_inout %.sink.output["a"] : !hw.inout<struct<a: !hw.struct<b: i1>>>
// CHECK-NEXT: %2 = hw.struct_extract %source["a"] : !hw.struct<a: !hw.struct<b: i1>>
// CHECK-NEXT: sv.assign %1, %2 : !hw.struct<b: i1>
// CHECK-NEXT: hw.output %0 : !hw.struct<a: !hw.struct<b: i1>>
}
}