[LowerTypes] Add `preservePublicTypes` option (#2424)

This commit adds preservePublicTypes option to preserve port types of 
top-level and exeternal modules. 
In chisel, external modules and top-level modules are implicitly assumed to have
lowered types. Therefore even in the aggregate presevartion mode, we
can't preserve them.
This commit is contained in:
uenoku 2022-01-18 23:51:54 +09:00 committed by GitHub
parent 677d780169
commit 39c0903589
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 17 deletions

View File

@ -30,7 +30,8 @@ createLowerFIRRTLAnnotationsPass(bool ignoreUnhandledAnnotations = false,
std::unique_ptr<mlir::Pass>
createLowerFIRRTLTypesPass(bool replSeqMem = false,
bool preserveAggregate = false);
bool preserveAggregate = false,
bool preservePublicTypes = true);
std::unique_ptr<mlir::Pass> createLowerBundleVectorTypesPass();

View File

@ -48,7 +48,10 @@ def LowerFIRRTLTypes : Pass<"firrtl-lower-types", "firrtl::CircuitOp"> {
Option<"flattenAggregateMemData", "flatten-mem", "bool", "false",
"Concat all elements of the aggregate data into a single element.">,
Option<"preserveAggregate", "preserve-aggregate", "bool", "false",
"Preserve passive aggregate types in the module.">
"Preserve passive aggregate types in the module.">,
Option<"preservePublicTypes", "preserve-public-types", "bool",
"true", "Force to lower ports of toplevel and external modules even"
"when aggregate preservation mode.">
];
let dependentDialects = ["hw::HWDialect"];
}

View File

@ -310,8 +310,11 @@ namespace {
// not.
struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
TypeLoweringVisitor(MLIRContext *context, bool f, bool p)
: context(context), flattenAggregateMemData(f), preserveAggregate(p) {}
TypeLoweringVisitor(MLIRContext *context, bool flattenAggregateMemData,
bool preserveAggregate, bool preservePublicTypes)
: context(context), flattenAggregateMemData(flattenAggregateMemData),
preserveAggregate(preserveAggregate),
preservePublicTypes(preservePublicTypes) {}
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
@ -355,6 +358,8 @@ private:
llvm::function_ref<Operation *(FlatBundleFieldEntry,
StringRef, ArrayAttr)>
clone);
bool isModuleAllowedToPreserveAggregate(Operation *moduleLike);
Value getSubWhatever(Value val, size_t index);
MLIRContext *context;
@ -366,11 +371,35 @@ private:
/// enabled.
bool preserveAggregate;
/// Exteranal modules and toplevel modules should have lowered types if this
/// flag is enabled.
bool preservePublicTypes;
/// The builder is set and maintained in the main loop.
ImplicitLocOpBuilder *builder;
};
} // namespace
/// Return true if we can preserve the arguments of the given module.
/// Exteranal modules and toplevel modules are sometimes assumed to have lowered
/// types.
bool TypeLoweringVisitor::isModuleAllowedToPreserveAggregate(
Operation *moduleLike) {
if (!preserveAggregate)
return false;
// If it is not forced to lower toplevel and external modules, it's ok to
// preserve.
if (!preservePublicTypes)
return true;
if (isa<FExtModuleOp>(moduleLike))
return false;
auto module = cast<FModuleOp>(moduleLike);
return cast<CircuitOp>(module->getParentOp()).getMainModule() != module;
}
Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
if (BundleType bundle = val.getType().dyn_cast<BundleType>()) {
return builder->create<SubfieldOp>(val, index);
@ -531,7 +560,8 @@ bool TypeLoweringVisitor::lowerArg(Operation *module, size_t argIndex,
// Flatten any bundle types.
SmallVector<FlatBundleFieldEntry> fieldTypes;
auto srcType = newArgs[argIndex].type.cast<FIRRTLType>();
if (!peelType(srcType, fieldTypes, preserveAggregate))
if (!peelType(srcType, fieldTypes,
isModuleAllowedToPreserveAggregate(module)))
return false;
for (auto field : llvm::enumerate(fieldTypes)) {
@ -917,7 +947,6 @@ bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
// Lower the module block arguments.
SmallVector<unsigned> argsToRemove;
auto newArgs = extModule.getPorts();
for (size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
SmallVector<Value> lowering;
@ -1193,6 +1222,8 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
SmallVector<Direction> newDirs;
SmallVector<Attribute> newNames;
SmallVector<Attribute> newPortAnno;
bool allowedToPreserveAggregate =
isModuleAllowedToPreserveAggregate(op.getReferencedModule());
endFields.push_back(0);
bool hasDontTouch = false;
@ -1201,7 +1232,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
// Flatten any nested bundle types the usual way.
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
if (!peelType(srcType, fieldTypes, preserveAggregate)) {
if (!peelType(srcType, fieldTypes, allowedToPreserveAggregate)) {
newDirs.push_back(op.getPortDirection(i));
newNames.push_back(op.getPortName(i));
resultTypes.push_back(srcType);
@ -1299,9 +1330,11 @@ bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
namespace {
struct LowerTypesPass : public LowerFIRRTLTypesBase<LowerTypesPass> {
LowerTypesPass(bool f, bool p) {
flattenAggregateMemData = f;
preserveAggregate = p;
LowerTypesPass(bool flattenAggregateMemDataFlag, bool preserveAggregateFlag,
bool preservePublicTypesFlag) {
flattenAggregateMemData = flattenAggregateMemDataFlag;
preserveAggregate = preserveAggregateFlag;
preservePublicTypes = preservePublicTypesFlag;
}
void runOnOperation() override;
};
@ -1315,14 +1348,15 @@ void LowerTypesPass::runOnOperation() {
mlir::parallelForEachN(&getContext(), 0, ops.size(), [&](auto index) {
TypeLoweringVisitor(&getContext(), flattenAggregateMemData,
preserveAggregate)
preserveAggregate, preservePublicTypes)
.lowerModule(ops[index]);
});
}
/// This is the pass constructor.
std::unique_ptr<mlir::Pass>
circt::firrtl::createLowerFIRRTLTypesPass(bool replSeqMem,
bool preserveAggregate) {
return std::make_unique<LowerTypesPass>(replSeqMem, preserveAggregate);
std::unique_ptr<mlir::Pass> circt::firrtl::createLowerFIRRTLTypesPass(
bool replSeqMem, bool preserveAggregate, bool preservePublicTypes) {
return std::make_unique<LowerTypesPass>(replSeqMem, preserveAggregate,
preservePublicTypes);
}

View File

@ -0,0 +1,13 @@
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true preserve-public-types=false})' %s | FileCheck --check-prefix=NOT_PRESERVE_PUBLIC_TYPES %s
firrtl.circuit "TopLevel" {
// CHECK-LABEL: firrtl.extmodule @External(in source_valid: !firrtl.uint<1>)
// CHECK-LABEL: firrtl.module @TopLevel(in %source_valid: !firrtl.uint<1>, out %sink_valid: !firrtl.uint<1>)
// NOT_PRESERVE_PUBLIC_TYPES-LABEL: firrtl.extmodule @External(in source: !firrtl.bundle<valid: uint<1>>)
// NOT_PRESERVE_PUBLIC_TYPES-LABEL: firrtl.module @TopLevel(in %source: !firrtl.bundle<valid: uint<1>>, out %sink: !firrtl.bundle<valid: uint<1>>)
firrtl.extmodule @External(in source: !firrtl.bundle<valid: uint<1>>)
firrtl.module @TopLevel(in %source: !firrtl.bundle<valid: uint<1>>,
out %sink: !firrtl.bundle<valid: uint<1>>) {
}
}

View File

@ -118,6 +118,12 @@ static cl::opt<bool>
preserveAggregate("preserve-aggregate",
cl::desc("preserve aggregate types in lower types"),
cl::init(false));
static cl::opt<bool> preservePublicTypes(
"preserve-public-types",
cl::desc("force to lower ports of toplevel and external modules"),
cl::init(true));
static cl::opt<std::string>
replSeqMemCircuit("repl-seq-mem-circuit",
cl::desc("circuit root for seq mem metadata"),
@ -356,8 +362,8 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr,
// The input mlir file could be firrtl dialect so we might need to clean
// things up.
if (lowerTypes) {
pm.addNestedPass<firrtl::CircuitOp>(
firrtl::createLowerFIRRTLTypesPass(replSeqMem, preserveAggregate));
pm.addNestedPass<firrtl::CircuitOp>(firrtl::createLowerFIRRTLTypesPass(
replSeqMem, preserveAggregate, preservePublicTypes));
// Only enable expand whens if lower types is also enabled.
if (expandWhens) {
auto &modulePM = pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>();