From 79f45c2b68f17fbe9fa9078843b42f726f48b678 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Mon, 18 Jul 2022 19:14:50 +0900 Subject: [PATCH] [LowerTypes] Add flags to refine aggregate preservation (#3490) This PR modifies LowerTypes to specify types preserved by aggregate preservation, i.e. `-preserve-aggregate={none, 1d-vec, vec, all}`. Implementation wise, `peelType` takes the enum and selectively lower types. --- include/circt/Dialect/FIRRTL/Passes.h | 24 +++- include/circt/Dialect/FIRRTL/Passes.td | 13 +- lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp | 122 +++++++++++------- .../Dialect/FIRRTL/lower-types-aggregate.mlir | 11 +- test/Dialect/FIRRTL/lower-types.mlir | 2 +- tools/firtool/firtool.cpp | 21 ++- 6 files changed, 132 insertions(+), 61 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/Passes.h b/include/circt/Dialect/FIRRTL/Passes.h index 6fa41a3fd3..c14467ecd2 100644 --- a/include/circt/Dialect/FIRRTL/Passes.h +++ b/include/circt/Dialect/FIRRTL/Passes.h @@ -28,9 +28,27 @@ std::unique_ptr createLowerFIRRTLAnnotationsPass(bool ignoreUnhandledAnnotations = false, bool ignoreClasslessAnnotations = false); -std::unique_ptr -createLowerFIRRTLTypesPass(bool preserveAggregate = false, - bool preservePublicTypes = true); +/// Configure which aggregate values will be preserved by the LowerTypes pass. +namespace PreserveAggregate { +enum PreserveMode { + /// Don't preserve aggregate at all. This has been default behaivor and + /// compatible with SFC. + None, + + /// Preserve only 1d vectors of ground type (e.g. UInt<2>[3]). + OneDimVec, + + /// Preserve only vectors (e.g. UInt<2>[3][3]). + Vec, + + /// Preserve all aggregate values. + All, +}; +} + +std::unique_ptr createLowerFIRRTLTypesPass( + PreserveAggregate::PreserveMode mode = PreserveAggregate::None, + bool preservePublicTypes = true); std::unique_ptr createLowerBundleVectorTypesPass(); diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index 9b08711c68..3be3c71b4d 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -56,11 +56,18 @@ def LowerFIRRTLTypes : Pass<"firrtl-lower-types", "firrtl::CircuitOp"> { let options = [ Option<"flattenAggregateMemData", "flatten-mem", "bool", "false", "Concat all elements of the aggregate data into a single element.">, - Option<"preserveAggregate", "preserve-aggregate", "bool", "false", - "Preserve passive aggregate types in the module.">, Option<"preservePublicTypes", "preserve-public-types", "bool", "true", "Force to lower ports of toplevel and external modules even" - "when aggregate preservation mode."> + "when aggregate preservation mode.">, + Option<"preserveAggregate", "preserve-aggregate", "PreserveAggregate::PreserveMode", + "PreserveAggregate::None", + "Specify aggregate preservation mode", + [{::llvm::cl::values( + clEnumValN(PreserveAggregate::None, "none", "Preserve no aggregate"), + clEnumValN(PreserveAggregate::OneDimVec, "1d-vec", "Preserve 1d vectors"), + clEnumValN(PreserveAggregate::Vec, "vec", "Preserve vectors"), + clEnumValN(PreserveAggregate::All, "all", "Preserve vectors and bundles") + )}]> ]; let dependentDialects = ["hw::HWDialect"]; } diff --git a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp index f12f32c6e4..12f43a630a 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp @@ -100,22 +100,59 @@ static bool hasZeroBitWidth(FIRRTLType type) { }); } -/// Return true if we can preserve the aggregate type. We can a preserve the -/// type iff (i) the type is not passive, (ii) the type doesn't contain analog -/// and (iii) type don't contain zero bitwidth. -static bool isPreservableAggregateType(Type type) { +/// Return true if the type is a 1d vector type or ground type. +static bool isOneDimVectorType(FIRRTLType type) { + return TypeSwitch(type) + .Case([&](auto bundle) { return false; }) + .Case([&](FVectorType vector) { + return vector.getElementType().isGround(); + }) + .Default([](auto groundType) { return true; }); +} + +/// Return true if the type has a bundle type as subtype. +static bool containsBundleType(FIRRTLType type) { + return TypeSwitch(type) + .Case([&](auto bundle) { return true; }) + .Case([&](FVectorType vector) { + return containsBundleType(vector.getElementType()); + }) + .Default([](auto groundType) { return false; }); +} + +/// Return true if we can preserve the type. +static bool isPreservableAggregateType(Type type, + PreserveAggregate::PreserveMode mode) { + // Return false if no aggregate value is preserved. + if (mode == PreserveAggregate::None) + return false; + auto firrtlType = type.cast(); - return firrtlType.isPassive() && !firrtlType.containsAnalog() && - !hasZeroBitWidth(firrtlType); + // We can a preserve the type iff (i) the type is not passive, (ii) the type + // doesn't contain analog and (iii) type don't contain zero bitwidth. + if (!firrtlType.isPassive() || firrtlType.containsAnalog() || + hasZeroBitWidth(firrtlType)) + return false; + + switch (mode) { + case PreserveAggregate::All: + return true; + case PreserveAggregate::OneDimVec: + return isOneDimVectorType(firrtlType); + case PreserveAggregate::Vec: + return !containsBundleType(firrtlType); + default: + llvm_unreachable("unexpected mode"); + } } /// Peel one layer of an aggregate type into its components. Type may be /// complex, but empty, in which case fields is empty, but the return is true. static bool peelType(Type type, SmallVectorImpl &fields, - bool allowedToPreserveAggregate = false) { + PreserveAggregate::PreserveMode mode) { // If the aggregate preservation is enabled and the type is preservable, // then just return. - if (allowedToPreserveAggregate && isPreservableAggregateType(type)) + if (isPreservableAggregateType(type, mode)) return false; return TypeSwitch(type) @@ -298,10 +335,11 @@ struct AttrCache { // not. struct TypeLoweringVisitor : public FIRRTLVisitor { - TypeLoweringVisitor(MLIRContext *context, bool preserveAggregate, + TypeLoweringVisitor(MLIRContext *context, + PreserveAggregate::PreserveMode preserveAggregate, bool preservePublicTypes, SymbolTable &symTbl, const AttrCache &cache) - : context(context), preserveAggregate(preserveAggregate), + : context(context), aggregatePreservationMode(preserveAggregate), preservePublicTypes(preservePublicTypes), symTbl(symTbl), cache(cache) { } using FIRRTLVisitor::visitDecl; @@ -359,7 +397,8 @@ private: FIRRTLType srcType, FlatBundleFieldEntry field, bool &needsSym, StringRef sym); - bool isModuleAllowedToPreserveAggregate(FModuleLike moduleLike); + PreserveAggregate::PreserveMode + getPreservatinoModeForModule(FModuleLike moduleLike); Value getSubWhatever(Value val, size_t index); size_t uniqueIdx = 0; @@ -370,9 +409,8 @@ private: MLIRContext *context; - /// Not to lower passive aggregate types as much as possible if this flag is - /// enabled. - bool preserveAggregate; + /// Aggregate preservation mode. + PreserveAggregate::PreserveMode aggregatePreservationMode; /// Exteranal modules and toplevel modules should have lowered types if this /// flag is enabled. @@ -416,23 +454,16 @@ private: }; } // namespace -/// Return true if we can preserve the arguments of the given module. -/// Exteranal modules and toplevel modules are sometimes assumed to have lowered -/// types. -bool TypeLoweringVisitor::isModuleAllowedToPreserveAggregate( - FModuleLike module) { +/// Return aggregate preservation mode for the module. If the module has a +/// public linkage, then it is not allowed to preserve aggregate values on ports +/// unless `preservePublicTypes` flag is disabled. +PreserveAggregate::PreserveMode +TypeLoweringVisitor::getPreservatinoModeForModule(FModuleLike module) { - if (!preserveAggregate) - return false; - - // If it is not forced to lower toplevel and external modules, it's ok to - // preserve. - if (!preservePublicTypes) - return true; - - if (isa(module)) - return false; - return !cast(*module).isPublic(); + if (aggregatePreservationMode != PreserveAggregate::None && + preservePublicTypes && cast(*module).isPublic()) + return PreserveAggregate::None; + return aggregatePreservationMode; } Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) { @@ -553,7 +584,7 @@ bool TypeLoweringVisitor::lowerProducer( auto srcType = op->getResult(0).getType().cast(); SmallVector fieldTypes; - if (!peelType(srcType, fieldTypes, preserveAggregate)) + if (!peelType(srcType, fieldTypes, aggregatePreservationMode)) return false; SmallVector lowered; @@ -715,8 +746,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex, // Flatten any bundle types. SmallVector fieldTypes; auto srcType = newArgs[argIndex].type.cast(); - if (!peelType(srcType, fieldTypes, - isModuleAllowedToPreserveAggregate(module))) + if (!peelType(srcType, fieldTypes, getPreservatinoModeForModule(module))) return false; for (const auto &field : llvm::enumerate(fieldTypes)) { @@ -773,8 +803,7 @@ bool TypeLoweringVisitor::visitStmt(ConnectOp op) { SmallVector fields; // We have to expand connections even if the aggregate preservation is true. - if (!peelType(op.getDest().getType(), fields, - /* allowedToPreserveAggregate */ false)) + if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None)) return false; // Loop over the leaf aggregates. @@ -797,8 +826,7 @@ bool TypeLoweringVisitor::visitStmt(StrictConnectOp op) { SmallVector fields; // We have to expand connections even if the aggregate preservation is true. - if (!peelType(op.getDest().getType(), fields, - /* allowedToPreserveAggregate */ false)) + if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None)) return false; // Loop over the leaf aggregates. @@ -833,7 +861,7 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) { SmallVector fields; // MemOp should have ground types so we can't preserve aggregates. - if (!peelType(op.getDataType(), fields, false)) + if (!peelType(op.getDataType(), fields, PreserveAggregate::None)) return false; SmallVector newMemories; @@ -1112,8 +1140,7 @@ bool TypeLoweringVisitor::visitExpr(BitCastOp op) { // UInt type result. That is, first bitcast the aggregate type to a UInt. // Attempt to get the bundle types. SmallVector fields; - if (peelType(op.getInput().getType(), fields, - /* allowedToPreserveAggregate */ false)) { + if (peelType(op.getInput().getType(), fields, PreserveAggregate::None)) { size_t uptoBits = 0; // Loop over the leaf aggregates and concat each of them to get a UInt. // Bitcast the fields to handle nested aggregate types. @@ -1176,8 +1203,8 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { SmallVector newDirs; SmallVector newNames; SmallVector newPortAnno; - bool allowedToPreserveAggregate = - isModuleAllowedToPreserveAggregate(op.getReferencedModule(symTbl)); + PreserveAggregate::PreserveMode mode = + getPreservatinoModeForModule(op.getReferencedModule(symTbl)); endFields.push_back(0); bool needsSymbol = false; @@ -1186,7 +1213,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { // Flatten any nested bundle types the usual way. SmallVector fieldTypes; - if (!peelType(srcType, fieldTypes, allowedToPreserveAggregate)) { + if (!peelType(srcType, fieldTypes, mode)) { newDirs.push_back(op.getPortDirection(i)); newNames.push_back(op.getPortName(i)); resultTypes.push_back(srcType); @@ -1294,7 +1321,9 @@ bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) { namespace { struct LowerTypesPass : public LowerFIRRTLTypesBase { - LowerTypesPass(bool preserveAggregateFlag, bool preservePublicTypesFlag) { + LowerTypesPass( + circt::firrtl::PreserveAggregate::PreserveMode preserveAggregateFlag, + bool preservePublicTypesFlag) { preserveAggregate = preserveAggregateFlag; preservePublicTypes = preservePublicTypesFlag; } @@ -1463,9 +1492,8 @@ void LowerTypesPass::runOnOperation() { /// This is the pass constructor. std::unique_ptr -circt::firrtl::createLowerFIRRTLTypesPass(bool preserveAggregate, +circt::firrtl::createLowerFIRRTLTypesPass(PreserveAggregate::PreserveMode mode, bool preservePublicTypes) { - return std::make_unique(preserveAggregate, - preservePublicTypes); + return std::make_unique(mode, preservePublicTypes); } diff --git a/test/Dialect/FIRRTL/lower-types-aggregate.mlir b/test/Dialect/FIRRTL/lower-types-aggregate.mlir index 7236c68651..0041738029 100644 --- a/test/Dialect/FIRRTL/lower-types-aggregate.mlir +++ b/test/Dialect/FIRRTL/lower-types-aggregate.mlir @@ -1,5 +1,7 @@ -// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck %s -// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true preserve-public-types=false})' %s | FileCheck --check-prefix=NOT_PRESERVE_PUBLIC_TYPES %s +// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all})' %s | FileCheck %s +// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all preserve-public-types=false})' %s | FileCheck --check-prefix=NOT_PRESERVE_PUBLIC_TYPES %s +// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=vec})' %s | FileCheck --check-prefix=VEC %s +// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=1d-vec})' %s | FileCheck --check-prefix=1D_VEC %s firrtl.circuit "TopLevel" { // CHECK-LABEL: firrtl.extmodule @External(in source_valid: !firrtl.uint<1>) @@ -10,4 +12,9 @@ firrtl.circuit "TopLevel" { firrtl.module @TopLevel(in %source: !firrtl.bundle>, out %sink: !firrtl.bundle>) { } + // CHECK: @Foo(in %a: !firrtl.bundle, 2>, 2>>) + // VEC: @Foo(in %a_a: !firrtl.vector, 2>, 2>) + // 1D_VEC: @Foo(in %a_a_0: !firrtl.vector, 2>, in %a_a_1: !firrtl.vector, 2>) + firrtl.module private @Foo(in %a: !firrtl.bundle, 2>, 2>>) { + } } diff --git a/test/Dialect/FIRRTL/lower-types.mlir b/test/Dialect/FIRRTL/lower-types.mlir index d49b059cd7..ba98e205d9 100644 --- a/test/Dialect/FIRRTL/lower-types.mlir +++ b/test/Dialect/FIRRTL/lower-types.mlir @@ -1,5 +1,5 @@ // RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types)' %s | FileCheck --check-prefixes=CHECK,COMMON %s -// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck --check-prefixes=AGGREGATE,COMMON %s +// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all})' %s | FileCheck --check-prefixes=AGGREGATE,COMMON %s firrtl.circuit "TopLevel" { diff --git a/tools/firtool/firtool.cpp b/tools/firtool/firtool.cpp index 23f428ab73..0e1e0f4096 100644 --- a/tools/firtool/firtool.cpp +++ b/tools/firtool/firtool.cpp @@ -128,10 +128,20 @@ static cl::opt replSeqMem( "replace the seq mem for macro replacement and emit relevant metadata"), cl::init(false), cl::cat(mainCategory)); -static cl::opt - preserveAggregate("preserve-aggregate", - cl::desc("preserve aggregate types in lower types"), - cl::init(false), cl::cat(mainCategory)); +static cl::opt + preserveAggregate( + "preserve-aggregate", cl::desc("Specify input file format:"), + llvm::cl::values(clEnumValN(circt::firrtl::PreserveAggregate::None, + "none", "Preserve no aggregate"), + clEnumValN(circt::firrtl::PreserveAggregate::OneDimVec, + "1d-vec", + "Preserve only 1d vectors of ground type"), + clEnumValN(circt::firrtl::PreserveAggregate::Vec, + "vec", "Preserve only vectors"), + clEnumValN(circt::firrtl::PreserveAggregate::All, + "all", "Preserve vectors and bundles")), + cl::init(circt::firrtl::PreserveAggregate::None), + cl::cat(mainCategory)); static cl::opt preservePublicTypes( "preserve-public-types", @@ -596,7 +606,8 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr, pm.nest().addPass( firrtl::createEmitOMIRPass(omirOutFile)); - if (!disableOptimization && preserveAggregate && mergeConnections) + if (!disableOptimization && + preserveAggregate != firrtl::PreserveAggregate::None && mergeConnections) pm.nest().nest().addPass( firrtl::createMergeConnectionsPass(mergeConnectionsAgggresively));