mirror of https://github.com/llvm/circt.git
[LowerTypes] Add flags to refine aggregate preservation (#3490)
This PR modifies LowerTypes to specify types preserved by aggregate preservation, i.e. `-preserve-aggregate={none, 1d-vec, vec, all}`. Implementation wise, `peelType` takes the enum and selectively lower types.
This commit is contained in:
parent
5d6dcdff80
commit
79f45c2b68
|
@ -28,9 +28,27 @@ std::unique_ptr<mlir::Pass>
|
|||
createLowerFIRRTLAnnotationsPass(bool ignoreUnhandledAnnotations = false,
|
||||
bool ignoreClasslessAnnotations = false);
|
||||
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createLowerFIRRTLTypesPass(bool preserveAggregate = false,
|
||||
bool preservePublicTypes = true);
|
||||
/// Configure which aggregate values will be preserved by the LowerTypes pass.
|
||||
namespace PreserveAggregate {
|
||||
enum PreserveMode {
|
||||
/// Don't preserve aggregate at all. This has been default behaivor and
|
||||
/// compatible with SFC.
|
||||
None,
|
||||
|
||||
/// Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
|
||||
OneDimVec,
|
||||
|
||||
/// Preserve only vectors (e.g. UInt<2>[3][3]).
|
||||
Vec,
|
||||
|
||||
/// Preserve all aggregate values.
|
||||
All,
|
||||
};
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::Pass> createLowerFIRRTLTypesPass(
|
||||
PreserveAggregate::PreserveMode mode = PreserveAggregate::None,
|
||||
bool preservePublicTypes = true);
|
||||
|
||||
std::unique_ptr<mlir::Pass> createLowerBundleVectorTypesPass();
|
||||
|
||||
|
|
|
@ -56,11 +56,18 @@ def LowerFIRRTLTypes : Pass<"firrtl-lower-types", "firrtl::CircuitOp"> {
|
|||
let options = [
|
||||
Option<"flattenAggregateMemData", "flatten-mem", "bool", "false",
|
||||
"Concat all elements of the aggregate data into a single element.">,
|
||||
Option<"preserveAggregate", "preserve-aggregate", "bool", "false",
|
||||
"Preserve passive aggregate types in the module.">,
|
||||
Option<"preservePublicTypes", "preserve-public-types", "bool",
|
||||
"true", "Force to lower ports of toplevel and external modules even"
|
||||
"when aggregate preservation mode.">
|
||||
"when aggregate preservation mode.">,
|
||||
Option<"preserveAggregate", "preserve-aggregate", "PreserveAggregate::PreserveMode",
|
||||
"PreserveAggregate::None",
|
||||
"Specify aggregate preservation mode",
|
||||
[{::llvm::cl::values(
|
||||
clEnumValN(PreserveAggregate::None, "none", "Preserve no aggregate"),
|
||||
clEnumValN(PreserveAggregate::OneDimVec, "1d-vec", "Preserve 1d vectors"),
|
||||
clEnumValN(PreserveAggregate::Vec, "vec", "Preserve vectors"),
|
||||
clEnumValN(PreserveAggregate::All, "all", "Preserve vectors and bundles")
|
||||
)}]>
|
||||
];
|
||||
let dependentDialects = ["hw::HWDialect"];
|
||||
}
|
||||
|
|
|
@ -100,22 +100,59 @@ static bool hasZeroBitWidth(FIRRTLType type) {
|
|||
});
|
||||
}
|
||||
|
||||
/// Return true if we can preserve the aggregate type. We can a preserve the
|
||||
/// type iff (i) the type is not passive, (ii) the type doesn't contain analog
|
||||
/// and (iii) type don't contain zero bitwidth.
|
||||
static bool isPreservableAggregateType(Type type) {
|
||||
/// Return true if the type is a 1d vector type or ground type.
|
||||
static bool isOneDimVectorType(FIRRTLType type) {
|
||||
return TypeSwitch<FIRRTLType, bool>(type)
|
||||
.Case<BundleType>([&](auto bundle) { return false; })
|
||||
.Case<FVectorType>([&](FVectorType vector) {
|
||||
return vector.getElementType().isGround();
|
||||
})
|
||||
.Default([](auto groundType) { return true; });
|
||||
}
|
||||
|
||||
/// Return true if the type has a bundle type as subtype.
|
||||
static bool containsBundleType(FIRRTLType type) {
|
||||
return TypeSwitch<FIRRTLType, bool>(type)
|
||||
.Case<BundleType>([&](auto bundle) { return true; })
|
||||
.Case<FVectorType>([&](FVectorType vector) {
|
||||
return containsBundleType(vector.getElementType());
|
||||
})
|
||||
.Default([](auto groundType) { return false; });
|
||||
}
|
||||
|
||||
/// Return true if we can preserve the type.
|
||||
static bool isPreservableAggregateType(Type type,
|
||||
PreserveAggregate::PreserveMode mode) {
|
||||
// Return false if no aggregate value is preserved.
|
||||
if (mode == PreserveAggregate::None)
|
||||
return false;
|
||||
|
||||
auto firrtlType = type.cast<FIRRTLType>();
|
||||
return firrtlType.isPassive() && !firrtlType.containsAnalog() &&
|
||||
!hasZeroBitWidth(firrtlType);
|
||||
// We can a preserve the type iff (i) the type is not passive, (ii) the type
|
||||
// doesn't contain analog and (iii) type don't contain zero bitwidth.
|
||||
if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
|
||||
hasZeroBitWidth(firrtlType))
|
||||
return false;
|
||||
|
||||
switch (mode) {
|
||||
case PreserveAggregate::All:
|
||||
return true;
|
||||
case PreserveAggregate::OneDimVec:
|
||||
return isOneDimVectorType(firrtlType);
|
||||
case PreserveAggregate::Vec:
|
||||
return !containsBundleType(firrtlType);
|
||||
default:
|
||||
llvm_unreachable("unexpected mode");
|
||||
}
|
||||
}
|
||||
|
||||
/// Peel one layer of an aggregate type into its components. Type may be
|
||||
/// complex, but empty, in which case fields is empty, but the return is true.
|
||||
static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
|
||||
bool allowedToPreserveAggregate = false) {
|
||||
PreserveAggregate::PreserveMode mode) {
|
||||
// If the aggregate preservation is enabled and the type is preservable,
|
||||
// then just return.
|
||||
if (allowedToPreserveAggregate && isPreservableAggregateType(type))
|
||||
if (isPreservableAggregateType(type, mode))
|
||||
return false;
|
||||
|
||||
return TypeSwitch<Type, bool>(type)
|
||||
|
@ -298,10 +335,11 @@ struct AttrCache {
|
|||
// not.
|
||||
struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
|
||||
|
||||
TypeLoweringVisitor(MLIRContext *context, bool preserveAggregate,
|
||||
TypeLoweringVisitor(MLIRContext *context,
|
||||
PreserveAggregate::PreserveMode preserveAggregate,
|
||||
bool preservePublicTypes, SymbolTable &symTbl,
|
||||
const AttrCache &cache)
|
||||
: context(context), preserveAggregate(preserveAggregate),
|
||||
: context(context), aggregatePreservationMode(preserveAggregate),
|
||||
preservePublicTypes(preservePublicTypes), symTbl(symTbl), cache(cache) {
|
||||
}
|
||||
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
|
||||
|
@ -359,7 +397,8 @@ private:
|
|||
FIRRTLType srcType, FlatBundleFieldEntry field,
|
||||
bool &needsSym, StringRef sym);
|
||||
|
||||
bool isModuleAllowedToPreserveAggregate(FModuleLike moduleLike);
|
||||
PreserveAggregate::PreserveMode
|
||||
getPreservatinoModeForModule(FModuleLike moduleLike);
|
||||
Value getSubWhatever(Value val, size_t index);
|
||||
|
||||
size_t uniqueIdx = 0;
|
||||
|
@ -370,9 +409,8 @@ private:
|
|||
|
||||
MLIRContext *context;
|
||||
|
||||
/// Not to lower passive aggregate types as much as possible if this flag is
|
||||
/// enabled.
|
||||
bool preserveAggregate;
|
||||
/// Aggregate preservation mode.
|
||||
PreserveAggregate::PreserveMode aggregatePreservationMode;
|
||||
|
||||
/// Exteranal modules and toplevel modules should have lowered types if this
|
||||
/// flag is enabled.
|
||||
|
@ -416,23 +454,16 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
/// Return true if we can preserve the arguments of the given module.
|
||||
/// Exteranal modules and toplevel modules are sometimes assumed to have lowered
|
||||
/// types.
|
||||
bool TypeLoweringVisitor::isModuleAllowedToPreserveAggregate(
|
||||
FModuleLike module) {
|
||||
/// Return aggregate preservation mode for the module. If the module has a
|
||||
/// public linkage, then it is not allowed to preserve aggregate values on ports
|
||||
/// unless `preservePublicTypes` flag is disabled.
|
||||
PreserveAggregate::PreserveMode
|
||||
TypeLoweringVisitor::getPreservatinoModeForModule(FModuleLike module) {
|
||||
|
||||
if (!preserveAggregate)
|
||||
return false;
|
||||
|
||||
// If it is not forced to lower toplevel and external modules, it's ok to
|
||||
// preserve.
|
||||
if (!preservePublicTypes)
|
||||
return true;
|
||||
|
||||
if (isa<FExtModuleOp>(module))
|
||||
return false;
|
||||
return !cast<hw::HWModuleLike>(*module).isPublic();
|
||||
if (aggregatePreservationMode != PreserveAggregate::None &&
|
||||
preservePublicTypes && cast<hw::HWModuleLike>(*module).isPublic())
|
||||
return PreserveAggregate::None;
|
||||
return aggregatePreservationMode;
|
||||
}
|
||||
|
||||
Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
|
||||
|
@ -553,7 +584,7 @@ bool TypeLoweringVisitor::lowerProducer(
|
|||
auto srcType = op->getResult(0).getType().cast<FIRRTLType>();
|
||||
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
|
||||
|
||||
if (!peelType(srcType, fieldTypes, preserveAggregate))
|
||||
if (!peelType(srcType, fieldTypes, aggregatePreservationMode))
|
||||
return false;
|
||||
|
||||
SmallVector<Value> lowered;
|
||||
|
@ -715,8 +746,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
|
|||
// Flatten any bundle types.
|
||||
SmallVector<FlatBundleFieldEntry> fieldTypes;
|
||||
auto srcType = newArgs[argIndex].type.cast<FIRRTLType>();
|
||||
if (!peelType(srcType, fieldTypes,
|
||||
isModuleAllowedToPreserveAggregate(module)))
|
||||
if (!peelType(srcType, fieldTypes, getPreservatinoModeForModule(module)))
|
||||
return false;
|
||||
|
||||
for (const auto &field : llvm::enumerate(fieldTypes)) {
|
||||
|
@ -773,8 +803,7 @@ bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
|
|||
SmallVector<FlatBundleFieldEntry> fields;
|
||||
|
||||
// We have to expand connections even if the aggregate preservation is true.
|
||||
if (!peelType(op.getDest().getType(), fields,
|
||||
/* allowedToPreserveAggregate */ false))
|
||||
if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
|
||||
return false;
|
||||
|
||||
// Loop over the leaf aggregates.
|
||||
|
@ -797,8 +826,7 @@ bool TypeLoweringVisitor::visitStmt(StrictConnectOp op) {
|
|||
SmallVector<FlatBundleFieldEntry> fields;
|
||||
|
||||
// We have to expand connections even if the aggregate preservation is true.
|
||||
if (!peelType(op.getDest().getType(), fields,
|
||||
/* allowedToPreserveAggregate */ false))
|
||||
if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
|
||||
return false;
|
||||
|
||||
// Loop over the leaf aggregates.
|
||||
|
@ -833,7 +861,7 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) {
|
|||
SmallVector<FlatBundleFieldEntry> fields;
|
||||
|
||||
// MemOp should have ground types so we can't preserve aggregates.
|
||||
if (!peelType(op.getDataType(), fields, false))
|
||||
if (!peelType(op.getDataType(), fields, PreserveAggregate::None))
|
||||
return false;
|
||||
|
||||
SmallVector<MemOp> newMemories;
|
||||
|
@ -1112,8 +1140,7 @@ bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
|
|||
// UInt type result. That is, first bitcast the aggregate type to a UInt.
|
||||
// Attempt to get the bundle types.
|
||||
SmallVector<FlatBundleFieldEntry> fields;
|
||||
if (peelType(op.getInput().getType(), fields,
|
||||
/* allowedToPreserveAggregate */ false)) {
|
||||
if (peelType(op.getInput().getType(), fields, PreserveAggregate::None)) {
|
||||
size_t uptoBits = 0;
|
||||
// Loop over the leaf aggregates and concat each of them to get a UInt.
|
||||
// Bitcast the fields to handle nested aggregate types.
|
||||
|
@ -1176,8 +1203,8 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
|
|||
SmallVector<Direction> newDirs;
|
||||
SmallVector<Attribute> newNames;
|
||||
SmallVector<Attribute> newPortAnno;
|
||||
bool allowedToPreserveAggregate =
|
||||
isModuleAllowedToPreserveAggregate(op.getReferencedModule(symTbl));
|
||||
PreserveAggregate::PreserveMode mode =
|
||||
getPreservatinoModeForModule(op.getReferencedModule(symTbl));
|
||||
|
||||
endFields.push_back(0);
|
||||
bool needsSymbol = false;
|
||||
|
@ -1186,7 +1213,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
|
|||
|
||||
// Flatten any nested bundle types the usual way.
|
||||
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
|
||||
if (!peelType(srcType, fieldTypes, allowedToPreserveAggregate)) {
|
||||
if (!peelType(srcType, fieldTypes, mode)) {
|
||||
newDirs.push_back(op.getPortDirection(i));
|
||||
newNames.push_back(op.getPortName(i));
|
||||
resultTypes.push_back(srcType);
|
||||
|
@ -1294,7 +1321,9 @@ bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
|
|||
|
||||
namespace {
|
||||
struct LowerTypesPass : public LowerFIRRTLTypesBase<LowerTypesPass> {
|
||||
LowerTypesPass(bool preserveAggregateFlag, bool preservePublicTypesFlag) {
|
||||
LowerTypesPass(
|
||||
circt::firrtl::PreserveAggregate::PreserveMode preserveAggregateFlag,
|
||||
bool preservePublicTypesFlag) {
|
||||
preserveAggregate = preserveAggregateFlag;
|
||||
preservePublicTypes = preservePublicTypesFlag;
|
||||
}
|
||||
|
@ -1463,9 +1492,8 @@ void LowerTypesPass::runOnOperation() {
|
|||
|
||||
/// This is the pass constructor.
|
||||
std::unique_ptr<mlir::Pass>
|
||||
circt::firrtl::createLowerFIRRTLTypesPass(bool preserveAggregate,
|
||||
circt::firrtl::createLowerFIRRTLTypesPass(PreserveAggregate::PreserveMode mode,
|
||||
bool preservePublicTypes) {
|
||||
|
||||
return std::make_unique<LowerTypesPass>(preserveAggregate,
|
||||
preservePublicTypes);
|
||||
return std::make_unique<LowerTypesPass>(mode, preservePublicTypes);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck %s
|
||||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true preserve-public-types=false})' %s | FileCheck --check-prefix=NOT_PRESERVE_PUBLIC_TYPES %s
|
||||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all})' %s | FileCheck %s
|
||||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all preserve-public-types=false})' %s | FileCheck --check-prefix=NOT_PRESERVE_PUBLIC_TYPES %s
|
||||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=vec})' %s | FileCheck --check-prefix=VEC %s
|
||||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=1d-vec})' %s | FileCheck --check-prefix=1D_VEC %s
|
||||
|
||||
firrtl.circuit "TopLevel" {
|
||||
// CHECK-LABEL: firrtl.extmodule @External(in source_valid: !firrtl.uint<1>)
|
||||
|
@ -10,4 +12,9 @@ firrtl.circuit "TopLevel" {
|
|||
firrtl.module @TopLevel(in %source: !firrtl.bundle<valid: uint<1>>,
|
||||
out %sink: !firrtl.bundle<valid: uint<1>>) {
|
||||
}
|
||||
// CHECK: @Foo(in %a: !firrtl.bundle<a: vector<vector<uint<1>, 2>, 2>>)
|
||||
// VEC: @Foo(in %a_a: !firrtl.vector<vector<uint<1>, 2>, 2>)
|
||||
// 1D_VEC: @Foo(in %a_a_0: !firrtl.vector<uint<1>, 2>, in %a_a_1: !firrtl.vector<uint<1>, 2>)
|
||||
firrtl.module private @Foo(in %a: !firrtl.bundle<a: vector<vector<uint<1>, 2>, 2>>) {
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types)' %s | FileCheck --check-prefixes=CHECK,COMMON %s
|
||||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck --check-prefixes=AGGREGATE,COMMON %s
|
||||
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=all})' %s | FileCheck --check-prefixes=AGGREGATE,COMMON %s
|
||||
|
||||
|
||||
firrtl.circuit "TopLevel" {
|
||||
|
|
|
@ -128,10 +128,20 @@ static cl::opt<bool> replSeqMem(
|
|||
"replace the seq mem for macro replacement and emit relevant metadata"),
|
||||
cl::init(false), cl::cat(mainCategory));
|
||||
|
||||
static cl::opt<bool>
|
||||
preserveAggregate("preserve-aggregate",
|
||||
cl::desc("preserve aggregate types in lower types"),
|
||||
cl::init(false), cl::cat(mainCategory));
|
||||
static cl::opt<circt::firrtl::PreserveAggregate::PreserveMode>
|
||||
preserveAggregate(
|
||||
"preserve-aggregate", cl::desc("Specify input file format:"),
|
||||
llvm::cl::values(clEnumValN(circt::firrtl::PreserveAggregate::None,
|
||||
"none", "Preserve no aggregate"),
|
||||
clEnumValN(circt::firrtl::PreserveAggregate::OneDimVec,
|
||||
"1d-vec",
|
||||
"Preserve only 1d vectors of ground type"),
|
||||
clEnumValN(circt::firrtl::PreserveAggregate::Vec,
|
||||
"vec", "Preserve only vectors"),
|
||||
clEnumValN(circt::firrtl::PreserveAggregate::All,
|
||||
"all", "Preserve vectors and bundles")),
|
||||
cl::init(circt::firrtl::PreserveAggregate::None),
|
||||
cl::cat(mainCategory));
|
||||
|
||||
static cl::opt<bool> preservePublicTypes(
|
||||
"preserve-public-types",
|
||||
|
@ -596,7 +606,8 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr,
|
|||
pm.nest<firrtl::CircuitOp>().addPass(
|
||||
firrtl::createEmitOMIRPass(omirOutFile));
|
||||
|
||||
if (!disableOptimization && preserveAggregate && mergeConnections)
|
||||
if (!disableOptimization &&
|
||||
preserveAggregate != firrtl::PreserveAggregate::None && mergeConnections)
|
||||
pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>().addPass(
|
||||
firrtl::createMergeConnectionsPass(mergeConnectionsAgggresively));
|
||||
|
||||
|
|
Loading…
Reference in New Issue