[FIRRTL] Add a memory flattening pass (#2967)

Create a new pass to flatten the aggregate memory instead of doing it in LowerTypes.
This commit removes the logic from LowerTypes and moves it into a new pass that 
should run before LowerTypes.
This commit is contained in:
Prithayan Barua 2022-04-27 18:24:19 -07:00 committed by GitHub
parent b9c82fa7bf
commit ac5036c3b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 417 additions and 175 deletions

View File

@ -29,8 +29,7 @@ createLowerFIRRTLAnnotationsPass(bool ignoreUnhandledAnnotations = false,
bool ignoreClasslessAnnotations = false);
std::unique_ptr<mlir::Pass>
createLowerFIRRTLTypesPass(bool replSeqMem = false,
bool preserveAggregate = false,
createLowerFIRRTLTypesPass(bool preserveAggregate = false,
bool preservePublicTypes = true);
std::unique_ptr<mlir::Pass> createLowerBundleVectorTypesPass();
@ -46,6 +45,8 @@ std::unique_ptr<mlir::Pass> createInlinerPass();
std::unique_ptr<mlir::Pass> createInferReadWritePass();
std::unique_ptr<mlir::Pass> createLowerMemoryPass();
std::unique_ptr<mlir::Pass> createBlackBoxMemoryPass();
std::unique_ptr<mlir::Pass>

View File

@ -409,4 +409,12 @@ def InferReadWrite : Pass<"firrtl-infer-rw", "firrtl::FModuleOp"> {
let constructor = "circt::firrtl::createInferReadWritePass()";
}
def LowerMemory : Pass<"firrtl-lower-memory", "firrtl::FModuleOp"> {
let summary = "Flatten aggregate memory data to a UInt";
let description = [{
This pass flattens the aggregate memory data field into a single UInt, and
inserts the appropriate logic to access the data.
}];
let constructor = "circt::firrtl::createLowerMemoryPass()";
}
#endif // CIRCT_DIALECT_FIRRTL_PASSES_TD

View File

@ -15,6 +15,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms
InferWidths.cpp
LowerAnnotations.cpp
LowerCHIRRTL.cpp
LowerMemory.cpp
LowerTypes.cpp
MergeConnections.cpp
ModuleInliner.cpp

View File

@ -0,0 +1,230 @@
//===- LowerMemory.cpp - Lower Memory Pass -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the LowerMemory pass.
//
//===----------------------------------------------------------------------===//
#include "PassDetails.h"
#include "circt/Dialect/FIRRTL/FIRRTLAnnotations.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/Namespace.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "lower-memory"
using namespace circt;
using namespace firrtl;
namespace {
struct LowerMemoryPass : public LowerMemoryBase<LowerMemoryPass> {
  /// Flatten the aggregate data type of every memory in this module into a
  /// single UInt, inserting bitcasts (and mask-bit replication where needed)
  /// so existing users of the memory ports keep their aggregate view.
  void runOnOperation() override {
    LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
                            << getOperation().getName());
    // True iff any port of `op` carries a SubAnnotation. Such annotations
    // target individual aggregate fields, so the memory cannot be flattened
    // and must instead be split into one memory per field.
    auto hasSubAnno = [&](MemOp op) -> bool {
      for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
        for (auto attr : op.getPortAnnotation(portIdx))
          if (attr.isa<SubAnnotationAttr>())
            return true;
      return false;
    };
    getOperation().getBody()->walk([&](MemOp memOp) {
      LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
      // The vector of leaf element types after flattening the data.
      SmallVector<IntType> flatMemType;
      // If subannotations are present on aggregate fields, we cannot flatten
      // the memory. It must be split into one memory per aggregate field.
      // Also bail out if the data type contains a field of unknown width.
      if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
        return;
      // Get the width of the individual aggregate leaf elements.
      SmallVector<int32_t> memWidths;
      for (auto f : flatMemType) {
        LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
        memWidths.push_back(f.getWidth().getValue());
      }
      // An empty aggregate flattens to no fields; leave such memories alone.
      // (This also guards the memWidths[0] access below.)
      if (memWidths.empty())
        return;
      // Total data bitwidth after flattening.
      size_t memFlatWidth = 0;
      // MaskGranularity: how many data bits each mask bit controls. This is
      // the GCD of all the leaf field widths.
      size_t maskGran = memWidths[0];
      for (auto w : memWidths) {
        memFlatWidth += w;
        maskGran = llvm::GreatestCommonDivisor64(maskGran, w);
      }
      // If every field is zero bits wide there is nothing to flatten (this
      // also avoids a division by zero computing the per-field mask widths).
      if (maskGran == 0)
        return;
      // How many mask bits each flattened field requires, and their total.
      SmallVector<unsigned> maskWidths;
      uint32_t totalmaskWidths = 0;
      for (auto w : memWidths) {
        auto mWidth = w / maskGran;
        maskWidths.push_back(mWidth);
        totalmaskWidths += mWidth;
      }
      // Now create a new memory of type flattened data.
      // ----------------------------------------------
      auto *context = memOp.getContext();
      ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
      // Create a new memory data type of unsigned int of the computed width.
      auto flatType = UIntType::get(context, memFlatWidth);
      SmallVector<Type, 8> ports;
      SmallVector<Attribute, 8> portNames;
      auto opPorts = memOp.getPorts();
      for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
        auto port = opPorts[portIdx];
        ports.push_back(MemOp::getTypeForPort(memOp.depth(), flatType,
                                              port.second, totalmaskWidths));
        portNames.push_back(port.first);
      }
      auto flatMem = builder.create<MemOp>(
          ports, memOp.readLatency(), memOp.writeLatency(), memOp.depth(),
          memOp.ruw(), builder.getArrayAttr(portNames), memOp.nameAttr(),
          memOp.annotations(), memOp.portAnnotations(), memOp.inner_symAttr(),
          memOp.groupIDAttr());
      // Hook up the new memory to the wires the old memory is replaced with.
      for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
           ++index) {
        auto result = memOp.getResult(index);
        // Create a wire of the old (aggregate) port type so that all
        // existing users of this port can be redirected to it.
        auto wire = builder.create<WireOp>(
            result.getType(),
            (memOp.name() + "_" + memOp.getPortName(index).getValue()).str());
        result.replaceAllUsesWith(wire.getResult());
        result = wire;
        auto newResult = flatMem.getResult(index);
        auto rType = result.getType().cast<BundleType>();
        for (size_t fieldIndex = 0, fend = rType.getNumElements();
             fieldIndex != fend; ++fieldIndex) {
          auto name = rType.getElement(fieldIndex).name.getValue();
          auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
          Value newField = builder.create<SubfieldOp>(newResult, fieldIndex);
          // Only data and mask depend on the memory type which was
          // flattened; all other fields (addr, en, clk, ...) connect
          // directly. Data/mask can go both directions, depending on the
          // port direction.
          if (!(name == "data" || name == "mask" || name == "wdata" ||
                name == "wmask" || name == "rdata")) {
            mkConnect(&builder, newField, oldField);
            continue;
          }
          Value realOldField = oldField;
          if (rType.getElement(fieldIndex).isFlip) {
            // Read port: cast the memory read data from flat type back to
            // the aggregate and write the aggregate read data to the wire.
            newField = builder.createOrFold<BitCastOp>(
                oldField.getType().cast<FIRRTLType>(), newField);
            mkConnect(&builder, realOldField, newField);
          } else {
            // Write port: cast the input aggregate write data to flat type.
            auto newFieldType = newField.getType().cast<FIRRTLType>();
            auto oldFieldBitWidth = getBitWidth(oldField.getType());
            // The following condition is true if a data field is 0 bits.
            // Then newFieldType is of smaller bits than old.
            if (getBitWidth(newFieldType) != oldFieldBitWidth.getValue())
              newFieldType =
                  UIntType::get(context, oldFieldBitWidth.getValue());
            realOldField = builder.create<BitCastOp>(newFieldType, oldField);
            // Mask bits require special handling, since some of the mask
            // bits need to be repeated, so direct bitcasting wouldn't work.
            // Depending on the mask granularity, some mask bits will be
            // repeated (and zero-bit data fields contribute no mask bits).
            if ((name == "mask" || name == "wmask") &&
                (maskWidths.size() != totalmaskWidths)) {
              Value catMasks;
              for (auto m : llvm::enumerate(maskWidths)) {
                // Get the mask bit.
                auto mBit = builder.createOrFold<BitsPrimOp>(
                    realOldField, m.index(), m.index());
                // Check how many times the mask bit needs to be prepended.
                for (size_t repeat = 0; repeat < m.value(); repeat++)
                  if ((m.index() == 0 && repeat == 0) || !catMasks)
                    catMasks = mBit;
                  else
                    catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
              }
              realOldField = catMasks;
            }
            // Now set the mask or write data, bitcasting to ensure that the
            // types match exactly.
            mkConnect(&builder, newField,
                      builder.createOrFold<BitCastOp>(
                          newField.getType().cast<FIRRTLType>(), realOldField));
          }
        }
      }
      memOp.erase();
    });
  }

private:
  /// Convert an aggregate type into a flat list of its ground-type fields.
  /// This is used to flatten the aggregate memory datatype; it recursively
  /// populates `results` with each ground type field. Returns false if the
  /// type contains anything other than bundles, vectors and fixed-width
  /// integers (in which case `results` is meaningless).
  static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
    std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
      return TypeSwitch<FIRRTLType, bool>(type)
          .Case<BundleType>([&](auto bundle) {
            for (auto &elt : bundle)
              if (!flatten(elt.type))
                return false;
            return true;
          })
          .Case<FVectorType>([&](auto vector) {
            for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
              if (!flatten(vector.getElementType()))
                return false;
            return true;
          })
          .Case<IntType>([&](auto iType) {
            results.push_back({iType});
            // Only integers with a known width can be flattened.
            return iType.getWidth().hasValue();
          })
          .Default([&](auto) { return false; });
    };
    return flatten(type);
  }

  /// Connect `src` to `dst`, using a strict connect when the types match
  /// exactly and an ordinary connect otherwise.
  void mkConnect(ImplicitLocOpBuilder *builder, Value dst, Value src) {
    auto dstType = dst.getType().cast<FIRRTLType>();
    auto srcType = src.getType().cast<FIRRTLType>();
    if (srcType == dstType)
      builder->create<StrictConnectOp>(dst, src);
    else
      builder->create<ConnectOp>(dst, src);
  }
};
} // end anonymous namespace
/// Construct an instance of the memory-flattening pass; see Passes.td for
/// the pass description.
std::unique_ptr<mlir::Pass> circt::firrtl::createLowerMemoryPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<LowerMemoryPass>();
  return pass;
}

View File

@ -304,11 +304,10 @@ struct AttrCache {
// not.
struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
TypeLoweringVisitor(MLIRContext *context, bool flattenAggregateMemData,
bool preserveAggregate, bool preservePublicTypes,
SymbolTable &symTbl, const AttrCache &cache)
: context(context), flattenAggregateMemData(flattenAggregateMemData),
preserveAggregate(preserveAggregate),
TypeLoweringVisitor(MLIRContext *context, bool preserveAggregate,
bool preservePublicTypes, SymbolTable &symTbl,
const AttrCache &cache)
: context(context), preserveAggregate(preserveAggregate),
preservePublicTypes(preservePublicTypes), symTbl(symTbl), cache(cache) {
}
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
@ -376,9 +375,6 @@ private:
}
MLIRContext *context;
/// Create a single memory from an aggregate type (instead of one per field)
/// if this flag is enabled.
bool flattenAggregateMemData;
/// Not to lower passive aggregate types as much as possible if this flag is
/// enabled.
@ -932,35 +928,6 @@ bool TypeLoweringVisitor::visitStmt(WhenOp op) {
return false; // don't delete the when!
}
// Convert an aggregate type into a flat list of fields.
// This is used to flatten the aggregate memory datatype.
// Recursively populate the results with each ground type field.
static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
return TypeSwitch<FIRRTLType, bool>(type)
.Case<BundleType>([&](auto bundle) {
for (auto &elt : bundle)
if (!flatten(elt.type))
return false;
return true;
})
.Case<FVectorType>([&](auto vector) {
for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
if (!flatten(vector.getElementType()))
return false;
return true;
})
.Case<IntType>([&](auto iType) {
results.push_back({iType});
return iType.getWidth().hasValue();
})
.Default([&](auto) { return false; });
};
if (flatten(type))
return true;
return false;
}
/// Lower memory operations. A new memory is created for every leaf
/// element in a memory's data type.
bool TypeLoweringVisitor::visitDecl(MemOp op) {
@ -983,78 +950,13 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) {
oldPorts.push_back(wire);
result.replaceAllUsesWith(wire.getResult());
}
// The vector of leaf elements type after flattening the data.
SmallVector<IntType> flatMemType;
// MaskGranularity : how many bits each mask bit controls.
size_t maskGran = 1;
// Total mask bitwidth after flattening.
uint32_t totalmaskWidths = 0;
// How many mask bits each field type requires.
SmallVector<unsigned> maskWidths;
auto hasSubAnno = [&]() -> bool {
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
for (auto attr : op.getPortAnnotation(portIdx)) {
if (auto subAnno = attr.dyn_cast<SubAnnotationAttr>())
return true;
}
return false;
};
// If subannotations present on aggregate fields, we cannot flatten the
// memory. It must be split into one memory per aggregate field.
// Do not overwrite the pass flag!
auto localFlattenAggregateMemData = flattenAggregateMemData;
if (localFlattenAggregateMemData)
if (hasSubAnno() || !flattenType(op.getDataType(), flatMemType))
localFlattenAggregateMemData = false;
if (localFlattenAggregateMemData) {
SmallVector<Operation *, 8> flatData;
SmallVector<int32_t> memWidths;
// Get the width of individual aggregate leaf elements.
for (auto f : flatMemType)
memWidths.push_back(f.getWidth().getValue());
maskGran = memWidths[0];
size_t memFlatWidth = 0;
// Compute the GCD of all data bitwidths.
for (auto w : memWidths) {
memFlatWidth += w;
maskGran = llvm::GreatestCommonDivisor64(maskGran, w);
}
for (auto w : memWidths) {
// How many mask bits required for each flattened field.
auto mWidth = w / maskGran;
maskWidths.push_back(mWidth);
totalmaskWidths += mWidth;
}
// Now create a new memory of type flattened data.
// ----------------------------------------------
SmallVector<Type, 8> ports;
SmallVector<Attribute, 8> portNames;
// Create a new memoty data type of unsigned and computed width.
auto flatType = UIntType::get(context, memFlatWidth);
auto opPorts = op.getPorts();
for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
auto port = opPorts[portIdx];
ports.push_back(MemOp::getTypeForPort(op.depth(), flatType, port.second,
totalmaskWidths));
portNames.push_back(port.first);
}
auto flatMem = builder->create<MemOp>(
ports, op.readLatency(), op.writeLatency(), op.depth(), op.ruw(),
portNames, op.name(), op.annotations().getValue(),
op.portAnnotations().getValue(), op.inner_symAttr());
// Done creating the memory.
// ----------------------------------------------
newMemories.push_back(flatMem);
} else {
// Memory for each field
for (auto field : fields)
newMemories.push_back(cloneMemWithNewType(builder, op, field));
}
// Hook up the new memories to the wires the old memory was replaced with.
for (size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
auto result = oldPorts[index];
@ -1067,55 +969,6 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) {
// go both directions, depending on the port direction.
if (name == "data" || name == "mask" || name == "wdata" ||
name == "wmask" || name == "rdata") {
if (localFlattenAggregateMemData) {
// If memory was flattened instead of one memory per aggregate field.
Value newField =
getSubWhatever(newMemories[0].getResult(index), fieldIndex);
Value realOldField = oldField;
if (rType.getElement(fieldIndex).isFlip) {
// Cast the memory read data from flat type to aggregate.
newField = builder->createOrFold<BitCastOp>(
oldField.getType().cast<FIRRTLType>(), newField);
// Write the aggregate read data.
mkConnect(builder, realOldField, newField);
} else {
// Cast the input aggregate write data to flat type.
auto newFieldType = newField.getType().cast<FIRRTLType>();
auto oldFieldBitWidth = getBitWidth(oldField.getType());
// Following condition is true, if a data field is 0 bits. Then
// newFieldType is of smaller bits than old.
if (getBitWidth(newFieldType) != oldFieldBitWidth.getValue())
newFieldType =
UIntType::get(context, oldFieldBitWidth.getValue());
realOldField = builder->create<BitCastOp>(newFieldType, oldField);
// Mask bits require special handling, since some of the mask bits
// need to be repeated, direct bitcasting wouldn't work. Depending
// on the mask granularity, some mask bits will be repeated.
// This also handles the case, when some of the data fields are 0
// bit, then the mask bits for zero bit data fields must be ignored.
if ((name == "mask" || name == "wmask") &&
(maskWidths.size() != totalmaskWidths)) {
Value catMasks;
for (auto m : llvm::enumerate(maskWidths)) {
// Get the mask bit.
auto mBit = builder->createOrFold<BitsPrimOp>(
realOldField, m.index(), m.index());
// Check how many times the mask bit needs to be prepend.
for (size_t repeat = 0; repeat < m.value(); repeat++)
if ((m.index() == 0 && repeat == 0) || !catMasks)
catMasks = mBit;
else
catMasks = builder->createOrFold<CatPrimOp>(mBit, catMasks);
}
realOldField = catMasks;
}
// Now set the mask or write data.
// Ensure that the types match.
mkConnect(builder, newField,
builder->createOrFold<BitCastOp>(
newField.getType().cast<FIRRTLType>(), realOldField));
}
} else {
for (auto field : fields) {
auto realOldField = getSubWhatever(oldField, field.index);
auto newField = getSubWhatever(
@ -1124,7 +977,6 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) {
std::swap(realOldField, newField);
mkConnect(builder, newField, realOldField);
}
}
} else {
for (auto mem : newMemories) {
auto newField =
@ -1530,9 +1382,7 @@ bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
namespace {
struct LowerTypesPass : public LowerFIRRTLTypesBase<LowerTypesPass> {
LowerTypesPass(bool flattenAggregateMemDataFlag, bool preserveAggregateFlag,
bool preservePublicTypesFlag) {
flattenAggregateMemData = flattenAggregateMemDataFlag;
LowerTypesPass(bool preserveAggregateFlag, bool preservePublicTypesFlag) {
preserveAggregate = preserveAggregateFlag;
preservePublicTypes = preservePublicTypesFlag;
}
@ -1567,9 +1417,8 @@ void LowerTypesPass::runOnOperation() {
std::mutex nlaAppendLock;
// This lambda, executes in parallel for each Op within the circt.
auto lowerModules = [&](FModuleLike op) -> void {
auto tl = TypeLoweringVisitor(&getContext(), flattenAggregateMemData,
preserveAggregate, preservePublicTypes,
symTbl, cache);
auto tl = TypeLoweringVisitor(&getContext(), preserveAggregate,
preservePublicTypes, symTbl, cache);
tl.lowerModule(op);
std::lock_guard<std::mutex> lg(nlaAppendLock);
@ -1723,9 +1572,10 @@ void LowerTypesPass::runOnOperation() {
}
/// This is the pass constructor.
std::unique_ptr<mlir::Pass> circt::firrtl::createLowerFIRRTLTypesPass(
bool replSeqMem, bool preserveAggregate, bool preservePublicTypes) {
std::unique_ptr<mlir::Pass>
circt::firrtl::createLowerFIRRTLTypesPass(bool preserveAggregate,
bool preservePublicTypes) {
return std::make_unique<LowerTypesPass>(replSeqMem, preserveAggregate,
return std::make_unique<LowerTypesPass>(preserveAggregate,
preservePublicTypes);
}

View File

@ -0,0 +1,150 @@
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl.module(firrtl-lower-memory))' %s | FileCheck %s
firrtl.circuit "Mem" {
firrtl.module public @Mem(in %clock: !firrtl.clock, in %rAddr: !firrtl.uint<4>, in %rEn: !firrtl.uint<1>, out %rData: !firrtl.bundle<a: uint<8>, b: uint<8>>, in %wAddr: !firrtl.uint<4>, in %wEn: !firrtl.uint<1>, in %wMask: !firrtl.bundle<a: uint<1>, b: uint<1>>, in %wData: !firrtl.bundle<a: uint<8>, b: uint<8>>) {
%memory_r, %memory_w = firrtl.mem Undefined {depth = 16 : i64, name = "memory", portNames = ["r", "w"], readLatency = 0 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>, !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>
%0 = firrtl.subfield %memory_r(2) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.clock
firrtl.strictconnect %0, %clock : !firrtl.clock
%1 = firrtl.subfield %memory_r(1) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.uint<1>
firrtl.strictconnect %1, %rEn : !firrtl.uint<1>
%2 = firrtl.subfield %memory_r(0) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.uint<4>
firrtl.strictconnect %2, %rAddr : !firrtl.uint<4>
%3 = firrtl.subfield %memory_r(3) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.bundle<a: uint<8>, b: uint<8>>
firrtl.strictconnect %rData, %3 : !firrtl.bundle<a: uint<8>, b: uint<8>>
%4 = firrtl.subfield %memory_w(2) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.clock
firrtl.strictconnect %4, %clock : !firrtl.clock
%5 = firrtl.subfield %memory_w(1) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.uint<1>
firrtl.strictconnect %5, %wEn : !firrtl.uint<1>
%6 = firrtl.subfield %memory_w(0) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.uint<4>
firrtl.strictconnect %6, %wAddr : !firrtl.uint<4>
%7 = firrtl.subfield %memory_w(4) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<1>, b: uint<1>>
firrtl.strictconnect %7, %wMask : !firrtl.bundle<a: uint<1>, b: uint<1>>
%8 = firrtl.subfield %memory_w(3) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<8>, b: uint<8>>
firrtl.strictconnect %8, %wData : !firrtl.bundle<a: uint<8>, b: uint<8>>
// ---------------------------------------------------------------------------------
// After flattening the memory data
// CHECK: %[[memory_r:.+]], %[[memory_w:.+]] = firrtl.mem Undefined {depth = 16 : i64, name = "memory", portNames = ["r", "w"], readLatency = 0 : i32, writeLatency = 1 : i32}
// CHECK-SAME: !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: uint<16>>, !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<16>, mask: uint<2>>
// CHECK: %[[memory_r_0:.+]] = firrtl.wire {name = "memory_r"} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>
// CHECK: %[[v0:.+]] = firrtl.subfield %[[memory_r]](0)
// CHECK: firrtl.strictconnect %[[v0]], %[[memory_r_addr:.+]] :
// CHECK: %[[v1:.+]] = firrtl.subfield %[[memory_r]](1)
// CHECK: firrtl.strictconnect %[[v1]], %[[memory_r_en:.+]] :
// CHECK: %[[v2:.+]] = firrtl.subfield %[[memory_r]](2)
// CHECK: firrtl.strictconnect %[[v2]], %[[memory_r_clk:.+]] :
// CHECK: %[[v3:.+]] = firrtl.subfield %[[memory_r]](3)
//
// ---------------------------------------------------------------------------------
// Read ports
// CHECK: %[[v4:.+]] = firrtl.bitcast %[[v3]] : (!firrtl.uint<16>) -> !firrtl.bundle<a: uint<8>, b: uint<8>>
// CHECK: firrtl.strictconnect %[[memory_r_data:.+]], %[[v4]] :
// --------------------------------------------------------------------------------
// Write Ports
// CHECK: %[[memory_w_1:.+]] = firrtl.wire {name = "memory_w"} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>
// CHECK: %[[v9:.+]] = firrtl.subfield %[[memory_w]](3)
// CHECK: %[[v17:.+]] = firrtl.bitcast %[[v15:.+]] : (!firrtl.bundle<a: uint<8>, b: uint<8>>) -> !firrtl.uint<16>
// CHECK: firrtl.strictconnect %[[v9]], %[[v17]]
//
// --------------------------------------------------------------------------------
// Mask Ports
// CHECK: %[[v11:.+]] = firrtl.subfield %[[memory_w]](4)
// CHECK: %[[v12:.+]] = firrtl.bitcast %[[v18:.+]] : (!firrtl.bundle<a: uint<1>, b: uint<1>>) -> !firrtl.uint<2>
// CHECK: firrtl.strictconnect %[[v11]], %[[v12]]
// --------------------------------------------------------------------------------
// Connections to module ports
// CHECK: %[[v21:.+]] = firrtl.subfield %[[memory_r_0]](2) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.clock
// CHECK: firrtl.strictconnect %[[v21]], %clock :
// CHECK: %[[v22:.+]] = firrtl.subfield %[[memory_r_0]](1) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.uint<1>
// CHECK: firrtl.strictconnect %[[v22]], %rEn : !firrtl.uint<1>
// CHECK: %[[v23:.+]] = firrtl.subfield %[[memory_r_0]](0) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.uint<4>
// CHECK: firrtl.strictconnect %[[v23]], %rAddr : !firrtl.uint<4>
// CHECK: %[[v24:.+]] = firrtl.subfield %[[memory_r_0]](3) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data flip: bundle<a: uint<8>, b: uint<8>>>) -> !firrtl.bundle<a: uint<8>, b: uint<8>>
// CHECK: firrtl.strictconnect %rData, %[[v24]] : !firrtl.bundle<a: uint<8>, b: uint<8>>
// CHECK: %[[v25:.+]] = firrtl.subfield %[[memory_w_1]](2) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.clock
// CHECK: firrtl.strictconnect %[[v25]], %clock : !firrtl.clock
// CHECK: %[[v26:.+]] = firrtl.subfield %[[memory_w_1]](1) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.uint<1>
// CHECK: firrtl.strictconnect %[[v26]], %wEn : !firrtl.uint<1>
// CHECK: %[[v27:.+]] = firrtl.subfield %[[memory_w_1]](0) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.uint<4>
// CHECK: firrtl.strictconnect %[[v27]], %wAddr : !firrtl.uint<4>
// CHECK: %[[v28:.+]] = firrtl.subfield %[[memory_w_1]](4) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<1>, b: uint<1>>
// CHECK: firrtl.strictconnect %[[v28]], %wMask : !firrtl.bundle<a: uint<1>, b: uint<1>>
// CHECK: %[[v29:.+]] = firrtl.subfield %[[memory_w_1]](3) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: bundle<a: uint<8>, b: uint<8>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<8>, b: uint<8>>
// CHECK: firrtl.strictconnect %[[v29]], %wData : !firrtl.bundle<a: uint<8>, b: uint<8>>
}
firrtl.module @MemoryRWSplit(in %clock: !firrtl.clock, in %rwEn: !firrtl.uint<1>, in %rwMode: !firrtl.uint<1>, in %rwAddr: !firrtl.uint<4>, in %rwMask: !firrtl.bundle<a: uint<1>, b: uint<1>>, in %rwDataIn: !firrtl.bundle<a: uint<8>, b: uint<9>>, out %rwDataOut: !firrtl.bundle<a: uint<8>, b: uint<9>>) {
%memory_rw = firrtl.mem Undefined {depth = 16 : i64, groupID = 1 : ui32, name = "memory", portNames = ["rw"], readLatency = 0 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>
// CHECK: %memory_rw = firrtl.mem Undefined {depth = 16 : i64, groupID = 1 : ui32, name = "memory", portNames = ["rw"], readLatency = 0 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: uint<17>, wmode: uint<1>, wdata: uint<17>, wmask: uint<17>>
// CHECK: %[[memory_rw_0:.+]] = firrtl.wire {name = "memory_rw"} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>
%0 = firrtl.subfield %memory_rw(3) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<8>, b: uint<9>>
%1 = firrtl.subfield %memory_rw(5) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<8>, b: uint<9>>
%2 = firrtl.subfield %memory_rw(6) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<1>, b: uint<1>>
%3 = firrtl.subfield %memory_rw(4) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.uint<1>
%4 = firrtl.subfield %memory_rw(0) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.uint<4>
%5 = firrtl.subfield %memory_rw(1) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.uint<1>
%6 = firrtl.subfield %memory_rw(2) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: bundle<a: uint<8>, b: uint<9>>, wmode: uint<1>, wdata: bundle<a: uint<8>, b: uint<9>>, wmask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.clock
// Drive every field of the aggregate read-write memory port: clock, enable,
// address, write-mode, the bundle-typed mask, and the bundle-typed write data;
// then read the bundle-typed read data back out.
firrtl.connect %6, %clock : !firrtl.clock, !firrtl.clock
firrtl.connect %5, %rwEn : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %4, %rwAddr : !firrtl.uint<4>, !firrtl.uint<4>
firrtl.connect %3, %rwMode : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %2, %rwMask : !firrtl.bundle<a: uint<1>, b: uint<1>>, !firrtl.bundle<a: uint<1>, b: uint<1>>
firrtl.connect %1, %rwDataIn : !firrtl.bundle<a: uint<8>, b: uint<9>>, !firrtl.bundle<a: uint<8>, b: uint<9>>
firrtl.connect %rwDataOut, %0 : !firrtl.bundle<a: uint<8>, b: uint<9>>, !firrtl.bundle<a: uint<8>, b: uint<9>>
// After the pass, the aggregate data field bundle<a: uint<8>, b: uint<9>>
// is expected to be flattened into a single uint<17> via firrtl.bitcast,
// and the per-field mask bits are replicated (via firrtl.cat chains) so that
// each mask bit covers every data bit of its corresponding field.
// CHECK: %[[v6:.+]] = firrtl.subfield %[[memory_rw_0]](3) :
// CHECK: %[[v7:.+]] = firrtl.subfield %memory_rw(3) :
// CHECK: %[[v8:.+]] = firrtl.bitcast %[[v7]] :
// CHECK: firrtl.strictconnect %[[v6]], %[[v8]] :
// CHECK: %[[v9:.+]] = firrtl.subfield %[[memory_rw_0]](4) :
// CHECK: %[[v10:.+]] = firrtl.subfield %memory_rw(4) :
// CHECK: firrtl.strictconnect %[[v10]], %[[v9]] : !firrtl.uint<1>
// CHECK: %[[v11:.+]] = firrtl.subfield %[[memory_rw_0]](5) :
// CHECK: %[[v12:.+]] = firrtl.subfield %memory_rw(5) :
// CHECK: %[[v13:.+]] = firrtl.bitcast %[[v11]] : (!firrtl.bundle<a: uint<8>, b: uint<9>>) -> !firrtl.uint<17>
// CHECK: firrtl.strictconnect %[[v12]], %[[v13]] :
// CHECK: %[[v14:.+]] = firrtl.subfield %[[memory_rw_0]](6) :
// CHECK: %[[v15:.+]] = firrtl.subfield %memory_rw(6) : (!firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, rdata flip: uint<17>, wmode: uint<1>, wdata: uint<17>, wmask: uint<17>>) -> !firrtl.uint<17>
// CHECK: %[[v16:.+]] = firrtl.bitcast %14 : (!firrtl.bundle<a: uint<1>, b: uint<1>>) -> !firrtl.uint<2>
// CHECK: %[[v17:.+]] = firrtl.bits %16 0 to 0 : (!firrtl.uint<2>) -> !firrtl.uint<1>
// CHECK: %[[v18:.+]] = firrtl.cat %[[v17]], %[[v17]] : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2>
// CHECK: %[[v19:.+]] = firrtl.cat %[[v17]], %[[v18]] : (!firrtl.uint<1>, !firrtl.uint<2>) -> !firrtl.uint<3>
// CHECK: %[[v24:.+]] = firrtl.cat %[[v17]], %[[v23:.+]] : (!firrtl.uint<1>, !firrtl.uint<7>) -> !firrtl.uint<8>
// CHECK: %[[v25:.+]] = firrtl.bits %16 1 to 1 : (!firrtl.uint<2>) -> !firrtl.uint<1>
// CHECK: %[[v26:.+]] = firrtl.cat %[[v25]], %[[v24]] : (!firrtl.uint<1>, !firrtl.uint<8>) -> !firrtl.uint<9>
// CHECK: %[[v27:.+]] = firrtl.cat %[[v25]], %[[v26]] : (!firrtl.uint<1>, !firrtl.uint<9>) -> !firrtl.uint<10>
// CHECK: %[[v28:.+]] = firrtl.cat %[[v25]], %[[v27]] : (!firrtl.uint<1>, !firrtl.uint<10>) -> !firrtl.uint<11>
// CHECK: %[[v34:.+]] = firrtl.cat %[[v25]], %[[v33:.+]] : (!firrtl.uint<1>, !firrtl.uint<16>) -> !firrtl.uint<17>
// CHECK: firrtl.strictconnect %[[v15]], %[[v34]] :
// Ensure 0 bit fields are handled properly.
// The uint<0> field contributes no data bits, so the nested
// bundle<entry: bundle<a: uint<0>, b: uint<1>, c: uint<2>>> flattens to a
// 3-bit data word with a matching 3-bit mask (see the memory type in the
// rewritten op below).
%ram_MPORT = firrtl.mem Undefined {depth = 4 : i64, groupID = 1 : ui32, name = "ram", portNames = ["MPORT"], readLatency = 0 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<2>, en: uint<1>, clk: clock, data: bundle<entry: bundle<a: uint<0>, b: uint<1>, c: uint<2>>>, mask: bundle<entry: bundle<a: uint<1>, b: uint<1>, c: uint<1>>>>
// CHECK: %ram_MPORT = firrtl.mem Undefined {depth = 4 : i64, groupID = 1 : ui32, name = "ram", portNames = ["MPORT"], readLatency = 0 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<2>, en: uint<1>, clk: clock, data: uint<3>, mask: uint<3>>
}
// Regression test: a memory whose data bundle contains a zero-width field
// (a: uint<0>) alongside a non-zero field (b: uint<20>). The zero-width
// field carries no bits, so the flattened memory is expected to end up
// with data: uint<20> and a single-bit mask: uint<1>; a wire preserving
// the original aggregate port shape is inserted so existing connects to
// the aggregate fields keep working.
firrtl.module @ZeroBitMasks(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %io: !firrtl.bundle<a: uint<0>, b: uint<20>>) {
%invalid = firrtl.invalidvalue : !firrtl.bundle<a: uint<1>, b: uint<1>>
%invalid_0 = firrtl.invalidvalue : !firrtl.bundle<a: uint<0>, b: uint<20>>
%ram_MPORT = firrtl.mem Undefined {depth = 1 : i64, groupID = 1 : ui32, name = "ram", portNames = ["MPORT"], readLatency = 0 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>
// Drive the aggregate data and mask fields with invalid values; the pass
// must route these through the flattened uint ports.
%3 = firrtl.subfield %ram_MPORT(3) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<0>, b: uint<20>>
firrtl.strictconnect %3, %invalid_0 : !firrtl.bundle<a: uint<0>, b: uint<20>>
%4 = firrtl.subfield %ram_MPORT(4) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<1>, b: uint<1>>
firrtl.strictconnect %4, %invalid : !firrtl.bundle<a: uint<1>, b: uint<1>>
// The rewritten memory: data flattened to uint<20> (the uint<0> field is
// dropped), mask collapsed to uint<1>; the original aggregate interface
// survives as the wire %ram_MPORT_1, bitcast-connected to the new ports.
// Note the mask handling: only the bit for the non-zero-width field
// (bits 1 to 1) is connected through.
// CHECK: %ram_MPORT = firrtl.mem Undefined {depth = 1 : i64, groupID = 1 : ui32, name = "ram", portNames = ["MPORT"], readLatency = 0 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<20>, mask: uint<1>>
// CHECK: %ram_MPORT_1 = firrtl.wire {name = "ram_MPORT"} : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>
// CHECK: %[[v6:.+]] = firrtl.subfield %ram_MPORT_1(3) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<0>, b: uint<20>>
// CHECK: %[[v7:.+]] = firrtl.subfield %ram_MPORT(3) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<20>, mask: uint<1>>) -> !firrtl.uint<20>
// CHECK: %[[v8:.+]] = firrtl.bitcast %6 : (!firrtl.bundle<a: uint<0>, b: uint<20>>) -> !firrtl.uint<20>
// CHECK: firrtl.strictconnect %7, %8 : !firrtl.uint<20>
// CHECK: %[[v9:.+]] = firrtl.subfield %ram_MPORT_1(4) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<1>, b: uint<1>>
// CHECK: %[[v10:.+]] = firrtl.subfield %ram_MPORT(4) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<20>, mask: uint<1>>) -> !firrtl.uint<1>
// CHECK: %[[v11:.+]] = firrtl.bitcast %9 : (!firrtl.bundle<a: uint<1>, b: uint<1>>) -> !firrtl.uint<2>
// CHECK: %[[v12:.+]] = firrtl.bits %11 0 to 0 : (!firrtl.uint<2>) -> !firrtl.uint<1>
// CHECK: %[[v13:.+]] = firrtl.bits %11 1 to 1 : (!firrtl.uint<2>) -> !firrtl.uint<1>
// CHECK: firrtl.strictconnect %[[v10]], %[[v13]] : !firrtl.uint<1>
// CHECK: %[[v14:.+]] = firrtl.subfield %ram_MPORT_1(3) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<0>, b: uint<20>>
// CHECK: firrtl.strictconnect %[[v14]], %invalid_0 : !firrtl.bundle<a: uint<0>, b: uint<20>>
// CHECK: %[[v15:.+]] = firrtl.subfield %ram_MPORT_1(4) : (!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: bundle<a: uint<0>, b: uint<20>>, mask: bundle<a: uint<1>, b: uint<1>>>) -> !firrtl.bundle<a: uint<1>, b: uint<1>>
firrtl.connect %3, %io : !firrtl.bundle<a: uint<0>, b: uint<20>>, !firrtl.bundle<a: uint<0>, b: uint<20>>
}
}

View File

@ -1,5 +1,4 @@
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types)' %s | FileCheck --check-prefixes=CHECK,COMMON %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{flatten-mem=true})' %s | FileCheck --check-prefix=FLATTEN %s
// RUN: circt-opt -pass-pipeline='firrtl.circuit(firrtl-lower-types{preserve-aggregate=true})' %s | FileCheck --check-prefixes=AGGREGATE,COMMON %s

View File

@ -466,11 +466,14 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr,
if (blackBoxMemory)
pm.nest<firrtl::CircuitOp>().addPass(firrtl::createBlackBoxMemoryPass());
if (replSeqMem)
pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>().addPass(
firrtl::createLowerMemoryPass());
// The input mlir file could be firrtl dialect so we might need to clean
// things up.
if (lowerTypes) {
pm.addNestedPass<firrtl::CircuitOp>(firrtl::createLowerFIRRTLTypesPass(
replSeqMem, preserveAggregate, preservePublicTypes));
preserveAggregate, preservePublicTypes));
// Only enable expand whens if lower types is also enabled.
if (expandWhens) {
auto &modulePM = pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>();