[FIRRTL] Handle flipped bundle types in LowerTypes (#263)

The canonical representation of a BundleType with every field flipped
is a FlipType that wraps the whole bundle. There were places where
LowerTypes assumed a BundleType, which caused crashes for
a FlipType wrapped around a BundleType. To address this:

* LowerTypes will default to disabled in firtool, and can be enabled by
passing the -enable-lower-types flag.
* The instance op verifier was updated to guarantee the result type is a
BundleType or a FlipType wrapped around a BundleType. 
* LowerTypes was updated to check for both scenarios in the places it
previously assumed it was working with a BundleType.
This commit is contained in:
mikeurbach 2020-11-20 11:55:07 -08:00 committed by GitHub
parent 5225d65e9f
commit 887513f8cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 169 additions and 41 deletions

View File

@ -56,6 +56,20 @@ static void flattenBundleTypes(FIRRTLType type, StringRef suffixSoFar,
}
}
// Helper to peel off the outer most flip type from a bundle that has all flips
// canonicalized to the outer level, or just return the bundle directly. For any
// other type, returns null.
static BundleType getCanonicalBundleType(Type originalType) {
BundleType originalBundleType;
if (auto flipType = originalType.dyn_cast<FlipType>())
originalBundleType = flipType.getElementType().dyn_cast<BundleType>();
else
originalBundleType = originalType.dyn_cast<BundleType>();
return originalBundleType;
}
//===----------------------------------------------------------------------===//
// Pass Infrastructure
//===----------------------------------------------------------------------===//
@ -180,15 +194,20 @@ void FIRRTLTypesLowering::lowerArg(BlockArgument arg, FIRRTLType type) {
// ensuring both lowerings are the same, we can process every module in the
// circuit in parallel, and every instance will have the correct ports.
void FIRRTLTypesLowering::visitDecl(InstanceOp op) {
// The instance's ports are represented as one value with a bundle type.
BundleType originalType = op.result().getType().cast<BundleType>();
// The instance's ports are represented as one value with a bundle type. Due
// to how bundles with flips are canonicalized, there may be an outer flip
// type that wraps the whole bundle.
Type originalType = op.result().getType();
bool isFlip = originalType.isa<FlipType>();
BundleType originalBundleType = getCanonicalBundleType(originalType);
assert(originalBundleType && "instance result was not a bundle type");
// Create a new, flat bundle type for the new result
SmallVector<BundleType::BundleElement, 8> bundleElements;
for (auto element : originalType.getElements()) {
for (auto element : originalBundleType.getElements()) {
// Flatten any nested bundle types the usual way.
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
flattenBundleTypes(element.second, element.first, false, fieldTypes);
flattenBundleTypes(element.second, element.first, isFlip, fieldTypes);
for (auto field : fieldTypes) {
// Store the flat type for the new bundle type.
@ -276,8 +295,13 @@ void FIRRTLTypesLowering::visitStmt(ConnectOp op) {
Value dest = op.dest();
Value src = op.src();
// Attempt to get the bundle types, potentially unwrapping an outer flip type
// that wraps the whole bundle.
BundleType destType = getCanonicalBundleType(dest.getType());
BundleType srcType = getCanonicalBundleType(src.getType());
// If we aren't connecting two bundles, there is nothing to do.
if (!dest.getType().isa<BundleType>() || !src.getType().isa<BundleType>())
if (!destType || !srcType)
return;
// Get the lowered values for each side.
@ -367,7 +391,8 @@ Value FIRRTLTypesLowering::getBundleLowering(Value oldValue,
// field.
void FIRRTLTypesLowering::getAllBundleLowerings(
Value value, SmallVectorImpl<Value> &results) {
BundleType bundleType = value.getType().cast<BundleType>();
BundleType bundleType = getCanonicalBundleType(value.getType());
assert(bundleType && "attempted to get bundle lowerings for non-bundle type");
for (auto element : bundleType.getElements())
results.push_back(getBundleLowering(value, element.first));
}

View File

@ -523,13 +523,18 @@ static LogicalResult verifyInstanceOp(InstanceOp &instance) {
return failure();
}
// Check that the result type is either a bundle type or a flip type that
// wraps a bundle type.
auto resultType = instance.getResult().getType().cast<FIRRTLType>();
if (!resultType.isa<BundleType>()) {
auto flipType = resultType.dyn_cast<FlipType>();
if (!flipType || !flipType.getElementType().isa<BundleType>())
return instance.emitOpError("has invalid result type of ") << resultType;
}
// Check that the result type is consistent with its module.
if (auto referencedFModule = dyn_cast<FModuleOp>(referencedModule)) {
auto bundle = instance.getResult()
.getType()
.cast<FIRRTLType>()
.getPassiveType()
.cast<BundleType>();
auto bundle = resultType.getPassiveType().cast<BundleType>();
auto bundleElements = bundle.getElements();
size_t e = bundleElements.size();

View File

@ -290,4 +290,32 @@ firrtl.module @X(%a : !firrtl.uint<4>) {
%0 = firrtl.bits %a 3 to 1 : (!firrtl.uint<4>) -> !firrtl.uint<2>
}
}
}
// -----
firrtl.circuit "TopModule" {
firrtl.module @SubModule(%a : !firrtl.uint<1>) {
}
firrtl.module @TopModule() {
// expected-error @+1 {{'firrtl.instance' op has invalid result type of '!firrtl.uint<1>'}}
%0 = firrtl.instance @SubModule : !firrtl.uint<1>
}
}
// -----
firrtl.circuit "TopModule" {
firrtl.module @SubModule(%a : !firrtl.uint<1>) {
}
firrtl.module @TopModule() {
// expected-error @+1 {{'firrtl.instance' op has invalid result type of '!firrtl.flip<uint<1>>'}}
%0 = firrtl.instance @SubModule : !firrtl.flip<uint<1>>
}
}

View File

@ -1,4 +1,4 @@
// RUN: circt-opt -pass-pipeline='firrtl.circuit(lower-firrtl-types)' %s | FileCheck %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(lower-firrtl-types)' -split-input-file %s | FileCheck %s
firrtl.circuit "TopLevel" {
@ -30,32 +30,6 @@ firrtl.circuit "TopLevel" {
}
}
// CHECK-LABEL: firrtl.module @Recursive
// CHECK-SAME: %[[FLAT_ARG_1_NAME:arg_foo_bar_baz]]: [[FLAT_ARG_1_TYPE:!firrtl.uint<1>]]
// CHECK-SAME: %[[FLAT_ARG_2_NAME:arg_foo_qux]]: [[FLAT_ARG_2_TYPE:!firrtl.sint<64>]]
// CHECK-SAME: %[[OUT_1_NAME:out1]]: [[OUT_1_TYPE:!firrtl.flip<uint<1>>]]
// CHECK-SAME: %[[OUT_2_NAME:out2]]: [[OUT_2_TYPE:!firrtl.flip<sint<64>>]]
firrtl.module @Recursive(%arg: !firrtl.bundle<foo: bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>>,
%out1: !firrtl.flip<uint<1>>, %out2: !firrtl.flip<sint<64>>) {
// CHECK-NEXT: firrtl.connect %[[OUT_1_NAME]], %[[FLAT_ARG_1_NAME]] : [[OUT_1_TYPE]], [[FLAT_ARG_1_TYPE]]
// CHECK-NEXT: firrtl.connect %[[OUT_2_NAME]], %[[FLAT_ARG_2_NAME]] : [[OUT_2_TYPE]], [[FLAT_ARG_2_TYPE]]
%0 = firrtl.subfield %arg("foo") : (!firrtl.bundle<foo: bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>>) -> !firrtl.bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>
%1 = firrtl.subfield %0("bar") : (!firrtl.bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>) -> !firrtl.bundle<baz: uint<1>>
%2 = firrtl.subfield %1("baz") : (!firrtl.bundle<baz: uint<1>>) -> !firrtl.uint<1>
%3 = firrtl.subfield %0("qux") : (!firrtl.bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>) -> !firrtl.sint<64>
firrtl.connect %out1, %2 : !firrtl.flip<uint<1>>, !firrtl.uint<1>
firrtl.connect %out2, %3 : !firrtl.flip<sint<64>>, !firrtl.sint<64>
}
// CHECK-LABEL: firrtl.module @Uniquification
// CHECK-SAME: %[[FLATTENED_ARG:a_b]]: [[FLATTENED_TYPE:!firrtl.uint<1>]],
// CHECK-NOT: %[[FLATTENED_ARG]]
// CHECK-SAME: %[[RENAMED_ARG:a_b.+]]: [[RENAMED_TYPE:!firrtl.uint<1>]] {firrtl.name = "[[FLATTENED_ARG]]"}
firrtl.module @Uniquification(%a: !firrtl.bundle<b: uint<1>>, %a_b: !firrtl.uint<1>) {
}
// CHECK-LABEL: firrtl.module @TopLevel
// CHECK-SAME: %[[SOURCE_VALID_NAME:source_valid]]: [[SOURCE_VALID_TYPE:!firrtl.uint<1>]]
// CHECK-SAME: %[[SOURCE_READY_NAME:source_ready]]: [[SOURCE_READY_TYPE:!firrtl.flip<uint<1>>]]
@ -97,3 +71,76 @@ firrtl.circuit "TopLevel" {
firrtl.connect %sink, %2 : !firrtl.bundle<valid: flip<uint<1>>, ready: uint<1>, data: flip<uint<64>>>, !firrtl.bundle<valid: uint<1>, ready: flip<uint<1>>, data: uint<64>>
}
}
// -----
firrtl.circuit "Recursive" {
// CHECK-LABEL: firrtl.module @Recursive
// CHECK-SAME: %[[FLAT_ARG_1_NAME:arg_foo_bar_baz]]: [[FLAT_ARG_1_TYPE:!firrtl.uint<1>]]
// CHECK-SAME: %[[FLAT_ARG_2_NAME:arg_foo_qux]]: [[FLAT_ARG_2_TYPE:!firrtl.sint<64>]]
// CHECK-SAME: %[[OUT_1_NAME:out1]]: [[OUT_1_TYPE:!firrtl.flip<uint<1>>]]
// CHECK-SAME: %[[OUT_2_NAME:out2]]: [[OUT_2_TYPE:!firrtl.flip<sint<64>>]]
firrtl.module @Recursive(%arg: !firrtl.bundle<foo: bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>>,
%out1: !firrtl.flip<uint<1>>, %out2: !firrtl.flip<sint<64>>) {
// CHECK-NEXT: firrtl.connect %[[OUT_1_NAME]], %[[FLAT_ARG_1_NAME]] : [[OUT_1_TYPE]], [[FLAT_ARG_1_TYPE]]
// CHECK-NEXT: firrtl.connect %[[OUT_2_NAME]], %[[FLAT_ARG_2_NAME]] : [[OUT_2_TYPE]], [[FLAT_ARG_2_TYPE]]
%0 = firrtl.subfield %arg("foo") : (!firrtl.bundle<foo: bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>>) -> !firrtl.bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>
%1 = firrtl.subfield %0("bar") : (!firrtl.bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>) -> !firrtl.bundle<baz: uint<1>>
%2 = firrtl.subfield %1("baz") : (!firrtl.bundle<baz: uint<1>>) -> !firrtl.uint<1>
%3 = firrtl.subfield %0("qux") : (!firrtl.bundle<bar: bundle<baz: uint<1>>, qux: sint<64>>) -> !firrtl.sint<64>
firrtl.connect %out1, %2 : !firrtl.flip<uint<1>>, !firrtl.uint<1>
firrtl.connect %out2, %3 : !firrtl.flip<sint<64>>, !firrtl.sint<64>
}
}
// -----
firrtl.circuit "Uniquification" {
// CHECK-LABEL: firrtl.module @Uniquification
// CHECK-SAME: %[[FLATTENED_ARG:a_b]]: [[FLATTENED_TYPE:!firrtl.uint<1>]],
// CHECK-NOT: %[[FLATTENED_ARG]]
// CHECK-SAME: %[[RENAMED_ARG:a_b.+]]: [[RENAMED_TYPE:!firrtl.uint<1>]] {firrtl.name = "[[FLATTENED_ARG]]"}
firrtl.module @Uniquification(%a: !firrtl.bundle<b: uint<1>>, %a_b: !firrtl.uint<1>) {
}
}
// -----
firrtl.circuit "InstanceFlipped" {
firrtl.module @SubModule(%clock: !firrtl.clock, %reset: !firrtl.uint<1>) {
}
// CHECK-LABEL: firrtl.module @InstanceFlipped
// CHECK: firrtl.instance @SubModule {{.*}} : !firrtl.flip<bundle<clock: clock, reset: uint<1>>>
firrtl.module @InstanceFlipped(%clock: !firrtl.clock, %reset: !firrtl.uint<1>) {
%int_bus = firrtl.instance @SubModule {name = "int_bus"} : !firrtl.flip<bundle<clock: clock, reset: uint<1>>>
// CHECK: firrtl.subfield %int_bus("clock") {{.*}} -> !firrtl.flip<clock>
%0 = firrtl.subfield %int_bus("clock") : (!firrtl.flip<bundle<clock: clock, reset: uint<1>>>) -> !firrtl.flip<clock>
firrtl.connect %0, %clock : !firrtl.flip<clock>, !firrtl.clock
// CHECK: firrtl.subfield %int_bus("reset") {{.*}} -> !firrtl.flip<uint<1>>
%1 = firrtl.subfield %int_bus("reset") : (!firrtl.flip<bundle<clock: clock, reset: uint<1>>>) -> !firrtl.flip<uint<1>>
firrtl.connect %1, %reset : !firrtl.flip<uint<1>>, !firrtl.uint<1>
}
}
// -----
firrtl.circuit "Top" {
// CHECK-LABEL: firrtl.module @Top
firrtl.module @Top(%in : !firrtl.bundle<a: uint<1>, b: uint<1>>,
%out : !firrtl.flip<bundle<a: uint<1>, b: uint<1>>>) {
// CHECK: firrtl.connect %out_a, %in_a : !firrtl.flip<uint<1>>, !firrtl.uint<1>
// CHECK: firrtl.connect %out_b, %in_b : !firrtl.flip<uint<1>>, !firrtl.uint<1>
firrtl.connect %out, %in : !firrtl.flip<bundle<a: uint<1>, b: uint<1>>>, !firrtl.bundle<a: uint<1>, b: uint<1>>
}
}

View File

@ -0,0 +1,17 @@
// RUN: firtool %s -format=mlir -lower-to-rtl | circt-opt -verify-diagnostics | FileCheck %s --check-prefix=LOWER
// RUN: firtool %s -format=mlir -lower-to-rtl -enable-lower-types | circt-opt -verify-diagnostics | FileCheck %s --check-prefix=LOWERTYPES
firrtl.circuit "Top" {
firrtl.module @Top(%in : !firrtl.bundle<a: uint<1>, b: uint<1>>,
%out : !firrtl.bundle<a: flip<uint<1>>, b: flip<uint<1>>>) {
firrtl.connect %out, %in : !firrtl.bundle<a: flip<uint<1>>, b: flip<uint<1>>>, !firrtl.bundle<a: uint<1>, b: uint<1>>
}
}
// LOWER-LABEL: module attributes {firrtl.mainModule = "Top"}
// expected-error: @+1 {{cannot lower this port type to RTL}}
// LOWERTYPES-LABEL: module attributes {firrtl.mainModule = "Top"}
// LOWERTYPES: %[[ARG0:.+]]: i1 {rtl.name = "in_a"}
// LOWERTYPES: %[[ARG1:.+]]: i1 {rtl.name = "in_b"}
// LOWERTYPES: rtl.output %[[ARG0]], %[[ARG1]]

View File

@ -54,6 +54,11 @@ static cl::opt<bool> disableOptimization("disable-opt",
static cl::opt<bool> lowerToRTL("lower-to-rtl",
cl::desc("run the lower-to-rtl pass"));
static cl::opt<bool>
enableLowerTypes("enable-lower-types",
cl::desc("run the lower-types pass within lower-to-rtl"),
cl::init(false));
static cl::opt<bool>
ignoreFIRLocations("ignore-fir-locators",
cl::desc("ignore the @info locations in the .fir file"),
@ -112,8 +117,9 @@ processBuffer(std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
// Run the lower-to-rtl pass if requested.
if (lowerToRTL) {
pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>().addPass(
firrtl::createLowerFIRRTLTypesPass());
if (enableLowerTypes)
pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>().addPass(
firrtl::createLowerFIRRTLTypesPass());
pm.addPass(firrtl::createLowerFIRRTLToRTLModulePass());
pm.nest<rtl::RTLModuleOp>().addPass(firrtl::createLowerFIRRTLToRTLPass());
}