[LowerTypes] Add flags to refine aggregate preservation (#3490)

This PR modifies LowerTypes to specify types preserved by aggregate preservation, i.e. `-preserve-aggregate={none, 1d-vec, vec, all}`.
Implementation wise, `peelType` takes the enum and selectively lower types.
This commit is contained in:
Hideto Ueno 2022-07-18 19:14:50 +09:00 committed by GitHub
parent 5d6dcdff80
commit 79f45c2b68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 132 additions and 61 deletions

View File

@ -28,9 +28,27 @@ std::unique_ptr<mlir::Pass>
createLowerFIRRTLAnnotationsPass(bool ignoreUnhandledAnnotations = false,
bool ignoreClasslessAnnotations = false);
std::unique_ptr<mlir::Pass>
createLowerFIRRTLTypesPass(bool preserveAggregate = false,
bool preservePublicTypes = true);
/// Configure which aggregate values will be preserved by the LowerTypes pass.
namespace PreserveAggregate {
enum PreserveMode {
/// Don't preserve aggregate at all. This has been default behaivor and
/// compatible with SFC.
None,
/// Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
OneDimVec,
/// Preserve only vectors (e.g. UInt<2>[3][3]).
Vec,
/// Preserve all aggregate values.
All,
};
}
std::unique_ptr<mlir::Pass> createLowerFIRRTLTypesPass(
PreserveAggregate::PreserveMode mode = PreserveAggregate::None,
bool preservePublicTypes = true);
std::unique_ptr<mlir::Pass> createLowerBundleVectorTypesPass();

View File

@ -56,11 +56,18 @@ def LowerFIRRTLTypes : Pass<"firrtl-lower-types", "firrtl::CircuitOp"> {
let options = [
Option<"flattenAggregateMemData", "flatten-mem", "bool", "false",
"Concat all elements of the aggregate data into a single element.">,
Option<"preserveAggregate", "preserve-aggregate", "bool", "false",
"Preserve passive aggregate types in the module.">,
Option<"preservePublicTypes", "preserve-public-types", "bool",
"true", "Force to lower ports of toplevel and external modules even"
"when aggregate preservation mode.">
"when aggregate preservation mode.">,
Option<"preserveAggregate", "preserve-aggregate", "PreserveAggregate::PreserveMode",
"PreserveAggregate::None",
"Specify aggregate preservation mode",
[{::llvm::cl::values(
clEnumValN(PreserveAggregate::None, "none", "Preserve no aggregate"),
clEnumValN(PreserveAggregate::OneDimVec, "1d-vec", "Preserve 1d vectors"),
clEnumValN(PreserveAggregate::Vec, "vec", "Preserve vectors"),
clEnumValN(PreserveAggregate::All, "all", "Preserve vectors and bundles")
)}]>
];
let dependentDialects = ["hw::HWDialect"];
}

View File

@ -100,22 +100,59 @@ static bool hasZeroBitWidth(FIRRTLType type) {
});
}
/// Return true if we can preserve the aggregate type. We can a preserve the
/// type iff (i) the type is not passive, (ii) the type doesn't contain analog
/// and (iii) type don't contain zero bitwidth.
static bool isPreservableAggregateType(Type type) {
/// Return true if the type is a 1d vector type or ground type.
static bool isOneDimVectorType(FIRRTLType type) {
return TypeSwitch<FIRRTLType, bool>(type)
.Case<BundleType>([&](auto bundle) { return false; })
.Case<FVectorType>([&](FVectorType vector) {
return vector.getElementType().isGround();
})
.Default([](auto groundType) { return true; });
}
/// Return true if the type has a bundle type as subtype.
static bool containsBundleType(FIRRTLType type) {
return TypeSwitch<FIRRTLType, bool>(type)
.Case<BundleType>([&](auto bundle) { return true; })
.Case<FVectorType>([&](FVectorType vector) {
return containsBundleType(vector.getElementType());
})
.Default([](auto groundType) { return false; });
}
/// Return true if we can preserve the type.
static bool isPreservableAggregateType(Type type,
PreserveAggregate::PreserveMode mode) {
// Return false if no aggregate value is preserved.
if (mode == PreserveAggregate::None)
return false;
auto firrtlType = type.cast<FIRRTLType>();
return firrtlType.isPassive() && !firrtlType.containsAnalog() &&
!hasZeroBitWidth(firrtlType);
// We can a preserve the type iff (i) the type is not passive, (ii) the type
// doesn't contain analog and (iii) type don't contain zero bitwidth.
if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
hasZeroBitWidth(firrtlType))
return false;
switch (mode) {
case PreserveAggregate::All:
return true;
case PreserveAggregate::OneDimVec:
return isOneDimVectorType(firrtlType);
case PreserveAggregate::Vec:
return !containsBundleType(firrtlType);
default:
llvm_unreachable("unexpected mode");
}
}
/// Peel one layer of an aggregate type into its components. Type may be
/// complex, but empty, in which case fields is empty, but the return is true.
static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
bool allowedToPreserveAggregate = false) {
PreserveAggregate::PreserveMode mode) {
// If the aggregate preservation is enabled and the type is preservable,
// then just return.
if (allowedToPreserveAggregate && isPreservableAggregateType(type))
if (isPreservableAggregateType(type, mode))
return false;
return TypeSwitch<Type, bool>(type)
@ -298,10 +335,11 @@ struct AttrCache {
// not.
struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
TypeLoweringVisitor(MLIRContext *context, bool preserveAggregate,
TypeLoweringVisitor(MLIRContext *context,
PreserveAggregate::PreserveMode preserveAggregate,
bool preservePublicTypes, SymbolTable &symTbl,
const AttrCache &cache)
: context(context), preserveAggregate(preserveAggregate),
: context(context), aggregatePreservationMode(preserveAggregate),
preservePublicTypes(preservePublicTypes), symTbl(symTbl), cache(cache) {
}
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
@ -359,7 +397,8 @@ private:
FIRRTLType srcType, FlatBundleFieldEntry field,
bool &needsSym, StringRef sym);
bool isModuleAllowedToPreserveAggregate(FModuleLike moduleLike);
PreserveAggregate::PreserveMode
getPreservatinoModeForModule(FModuleLike moduleLike);
Value getSubWhatever(Value val, size_t index);
size_t uniqueIdx = 0;
@ -370,9 +409,8 @@ private:
MLIRContext *context;
/// Not to lower passive aggregate types as much as possible if this flag is
/// enabled.
bool preserveAggregate;
/// Aggregate preservation mode.
PreserveAggregate::PreserveMode aggregatePreservationMode;
/// Exteranal modules and toplevel modules should have lowered types if this
/// flag is enabled.
@ -416,23 +454,16 @@ private:
};
} // namespace
/// Return true if we can preserve the arguments of the given module.
/// Exteranal modules and toplevel modules are sometimes assumed to have lowered
/// types.
bool TypeLoweringVisitor::isModuleAllowedToPreserveAggregate(
FModuleLike module) {
/// Return aggregate preservation mode for the module. If the module has a
/// public linkage, then it is not allowed to preserve aggregate values on ports
/// unless `preservePublicTypes` flag is disabled.
PreserveAggregate::PreserveMode
TypeLoweringVisitor::getPreservatinoModeForModule(FModuleLike module) {
if (!preserveAggregate)
return false;
// If it is not forced to lower toplevel and external modules, it's ok to
// preserve.
if (!preservePublicTypes)
return true;
if (isa<FExtModuleOp>(module))
return false;
return !cast<hw::HWModuleLike>(*module).isPublic();
if (aggregatePreservationMode != PreserveAggregate::None &&
preservePublicTypes && cast<hw::HWModuleLike>(*module).isPublic())
return PreserveAggregate::None;
return aggregatePreservationMode;
}
Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
@ -553,7 +584,7 @@ bool TypeLoweringVisitor::lowerProducer(
auto srcType = op->getResult(0).getType().cast<FIRRTLType>();
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
if (!peelType(srcType, fieldTypes, preserveAggregate))
if (!peelType(srcType, fieldTypes, aggregatePreservationMode))
return false;
SmallVector<Value> lowered;
@ -715,8 +746,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
// Flatten any bundle types.
SmallVector<FlatBundleFieldEntry> fieldTypes;
auto srcType = newArgs[argIndex].type.cast<FIRRTLType>();
if (!peelType(srcType, fieldTypes,
isModuleAllowedToPreserveAggregate(module)))
if (!peelType(srcType, fieldTypes, getPreservatinoModeForModule(module)))
return false;
for (const auto &field : llvm::enumerate(fieldTypes)) {
@ -773,8 +803,7 @@ bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
SmallVector<FlatBundleFieldEntry> fields;
// We have to expand connections even if the aggregate preservation is true.
if (!peelType(op.getDest().getType(), fields,
/* allowedToPreserveAggregate */ false))
if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
return false;
// Loop over the leaf aggregates.
@ -797,8 +826,7 @@ bool TypeLoweringVisitor::visitStmt(StrictConnectOp op) {
SmallVector<FlatBundleFieldEntry> fields;
// We have to expand connections even if the aggregate preservation is true.
if (!peelType(op.getDest().getType(), fields,
/* allowedToPreserveAggregate */ false))
if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
return false;
// Loop over the leaf aggregates.
@ -833,7 +861,7 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) {
SmallVector<FlatBundleFieldEntry> fields;
// MemOp should have ground types so we can't preserve aggregates.
if (!peelType(op.getDataType(), fields, false))
if (!peelType(op.getDataType(), fields, PreserveAggregate::None))
return false;
SmallVector<MemOp> newMemories;
@ -1112,8 +1140,7 @@ bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
// UInt type result. That is, first bitcast the aggregate type to a UInt.
// Attempt to get the bundle types.
SmallVector<FlatBundleFieldEntry> fields;
if (peelType(op.getInput().getType(), fields,
/* allowedToPreserveAggregate */ false)) {
if (peelType(op.getInput().getType(), fields, PreserveAggregate::None)) {
size_t uptoBits = 0;
// Loop over the leaf aggregates and concat each of them to get a UInt.
// Bitcast the fields to handle nested aggregate types.
@ -1176,8 +1203,8 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
SmallVector<Direction> newDirs;
SmallVector<Attribute> newNames;
SmallVector<Attribute> newPortAnno;
bool allowedToPreserveAggregate =
isModuleAllowedToPreserveAggregate(op.getReferencedModule(symTbl));
PreserveAggregate::PreserveMode mode =
getPreservatinoModeForModule(op.getReferencedModule(symTbl));
endFields.push_back(0);
bool needsSymbol = false;
@ -1186,7 +1213,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
// Flatten any nested bundle types the usual way.
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
if (!peelType(srcType, fieldTypes, allowedToPreserveAggregate)) {
if (!peelType(srcType, fieldTypes, mode)) {
newDirs.push_back(op.getPortDirection(i));
newNames.push_back(op.getPortName(i));
resultTypes.push_back(srcType);
@ -1294,7 +1321,9 @@ bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
namespace {
struct LowerTypesPass : public LowerFIRRTLTypesBase<LowerTypesPass> {
LowerTypesPass(bool preserveAggregateFlag, bool preservePublicTypesFlag) {
LowerTypesPass(
circt::firrtl::PreserveAggregate::PreserveMode preserveAggregateFlag,
bool preservePublicTypesFlag) {
preserveAggregate = preserveAggregateFlag;
preservePublicTypes = preservePublicTypesFlag;
}
@ -1463,9 +1492,8 @@ void LowerTypesPass::runOnOperation() {
/// This is the pass constructor.
std::unique_ptr<mlir::Pass>
circt::firrtl::createLowerFIRRTLTypesPass(bool preserveAggregate,
circt::firrtl::createLowerFIRRTLTypesPass(PreserveAggregate::PreserveMode mode,
bool preservePublicTypes) {
return std::make_unique<LowerTypesPass>(preserveAggregate,
preservePublicTypes);
return std::make_unique<LowerTypesPass>(mode, preservePublicTypes);
}

View File

@ -1,5 +1,7 @@
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true preserve-public-types=false})' %s | FileCheck --check-prefix=NOT_PRESERVE_PUBLIC_TYPES %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all})' %s | FileCheck %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all preserve-public-types=false})' %s | FileCheck --check-prefix=NOT_PRESERVE_PUBLIC_TYPES %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=vec})' %s | FileCheck --check-prefix=VEC %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=1d-vec})' %s | FileCheck --check-prefix=1D_VEC %s
firrtl.circuit "TopLevel" {
// CHECK-LABEL: firrtl.extmodule @External(in source_valid: !firrtl.uint<1>)
@ -10,4 +12,9 @@ firrtl.circuit "TopLevel" {
firrtl.module @TopLevel(in %source: !firrtl.bundle<valid: uint<1>>,
out %sink: !firrtl.bundle<valid: uint<1>>) {
}
// CHECK: @Foo(in %a: !firrtl.bundle<a: vector<vector<uint<1>, 2>, 2>>)
// VEC: @Foo(in %a_a: !firrtl.vector<vector<uint<1>, 2>, 2>)
// 1D_VEC: @Foo(in %a_a_0: !firrtl.vector<uint<1>, 2>, in %a_a_1: !firrtl.vector<uint<1>, 2>)
firrtl.module private @Foo(in %a: !firrtl.bundle<a: vector<vector<uint<1>, 2>, 2>>) {
}
}

View File

@ -1,5 +1,5 @@
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types)' %s | FileCheck --check-prefixes=CHECK,COMMON %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck --check-prefixes=AGGREGATE,COMMON %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all})' %s | FileCheck --check-prefixes=AGGREGATE,COMMON %s
firrtl.circuit "TopLevel" {

View File

@ -128,10 +128,20 @@ static cl::opt<bool> replSeqMem(
"replace the seq mem for macro replacement and emit relevant metadata"),
cl::init(false), cl::cat(mainCategory));
static cl::opt<bool>
preserveAggregate("preserve-aggregate",
cl::desc("preserve aggregate types in lower types"),
cl::init(false), cl::cat(mainCategory));
static cl::opt<circt::firrtl::PreserveAggregate::PreserveMode>
preserveAggregate(
"preserve-aggregate", cl::desc("Specify input file format:"),
llvm::cl::values(clEnumValN(circt::firrtl::PreserveAggregate::None,
"none", "Preserve no aggregate"),
clEnumValN(circt::firrtl::PreserveAggregate::OneDimVec,
"1d-vec",
"Preserve only 1d vectors of ground type"),
clEnumValN(circt::firrtl::PreserveAggregate::Vec,
"vec", "Preserve only vectors"),
clEnumValN(circt::firrtl::PreserveAggregate::All,
"all", "Preserve vectors and bundles")),
cl::init(circt::firrtl::PreserveAggregate::None),
cl::cat(mainCategory));
static cl::opt<bool> preservePublicTypes(
"preserve-public-types",
@ -596,7 +606,8 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr,
pm.nest<firrtl::CircuitOp>().addPass(
firrtl::createEmitOMIRPass(omirOutFile));
if (!disableOptimization && preserveAggregate && mergeConnections)
if (!disableOptimization &&
preserveAggregate != firrtl::PreserveAggregate::None && mergeConnections)
pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>().addPass(
firrtl::createMergeConnectionsPass(mergeConnectionsAgggresively));