[FIRRTL][InferRW] Set RWmode to the complement term in Write enable (#3759)

The `InferRW` pass transforms a memory with read and write ports to a single 
ReadWrite port memory, if it can prove that the read and write enable are
 mutually exclusive. The algorithm checks if any of the terms in the `And`
 expression tree of read and write enable is a complement of each other, to
infer if the read and write enable are trivially mutually exclusive.
The `RWmode` of the ReadWrite memory is set to `1` to use the memory in
 write mode and `0` for read mode. 
 
This PR sets the `RWmode` to the term in the `And` expression tree of
 the write enable, which proves the mutual exclusion, instead of
 setting it to the write enable. This is done to ensure equivalence with the
 `firrtl` compiler.
 
For example, if, `write enable = A && B`, `read enable = C && ~B`
 implies `RWmode = B`.
This commit is contained in:
Prithayan Barua 2022-08-23 12:44:16 -07:00 committed by GitHub
parent 0d82b4bb24
commit 7c11a9fa01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 27 deletions

View File

@ -55,12 +55,12 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
continue;
Value rClock, wClock;
// The memory has exactly two ports.
SmallVector<Value> portTerms[2];
SmallVector<Value> readTerms, writeTerms;
for (const auto &portIt : llvm::enumerate(memOp.getResults())) {
// Get the port value.
Value portVal = portIt.value();
// Get the port kind.
bool readPort =
bool isReadPort =
memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
// Iterate over all users of the port.
for (Operation *u : portVal.getUsers())
@ -72,9 +72,10 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
// If this is the enable field, record the product terms(the And
// expression tree).
if (fName.equals("en"))
getProductTerms(sf, portTerms[portIt.index()]);
getProductTerms(sf, isReadPort ? readTerms : writeTerms);
else if (fName.equals("clk")) {
if (readPort)
if (isReadPort)
rClock = getConnectSrc(sf);
else
wClock = getConnectSrc(sf);
@ -90,17 +91,19 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
llvm::dbgs() << "\n read clock:" << rClock
<< " --- write clock:" << wClock;
llvm::dbgs() << "\n Read terms==>"; for (auto t
: portTerms[0]) llvm::dbgs()
: readTerms) llvm::dbgs()
<< "\n term::" << t;
llvm::dbgs() << "\n Write terms==>"; for (auto t
: portTerms[1]) llvm::dbgs()
: writeTerms) llvm::dbgs()
<< "\n term::" << t;
);
// If the read and write clocks are the same, check if any of the product
// terms are a complement of each other.
if (!checkComplement(portTerms))
// If the read and write clocks are the same, and if any of the write
// enable product terms are a complement of the read enable, then return
// the write enable term.
auto complementTerm = checkComplement(readTerms, writeTerms);
if (!complementTerm)
continue;
SmallVector<Attribute, 4> resultNames;
@ -158,14 +161,14 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
// Enable = Or(WriteEnable, ReadEnable).
builder.create<StrictConnectOp>(
enb, builder.create<OrPrimOp>(rEnWire, wEnWire));
// WriteMode = WriteEnable.
builder.create<StrictConnectOp>(wmode, wEnWire);
builder.setInsertionPointToEnd(wmode->getBlock());
builder.create<StrictConnectOp>(wmode, complementTerm);
// Now iterate over the original memory read and write ports.
for (const auto &portIt : llvm::enumerate(memOp.getResults())) {
// Get the port value.
Value portVal = portIt.value();
// Get the port kind.
bool readPort =
bool isReadPort =
memOp.getPortKind(portIt.index()) == MemOp::PortKind::Read;
// Iterate over all users of the port, which are the subfield ops, and
// replace them.
@ -175,7 +178,7 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
sf.getInput().getType().cast<BundleType>().getElementName(
sf.getFieldIndex());
Value repl;
if (readPort)
if (isReadPort)
repl = llvm::StringSwitch<Value>(fName)
.Case("en", rEnWire)
.Case("clk", clk)
@ -272,27 +275,28 @@ private:
}
}
/// Check if any of the terms in the prodTerms[0] is a complement of any of
/// the terms in prodTerms[1]. prodTerms[0], prodTerms[1] is a vector of
/// Value, each of which correspond to the two product terms of read/write
/// enable.
bool checkComplement(SmallVector<Value> prodTerms[2]) {
bool isComplement = false;
/// If any of the terms in the read enable, prodTerms[0] is a complement of
/// any of the terms in the write enable prodTerms[1], return the
/// corresponding write enable term. prodTerms[0], prodTerms[1] is a vector of
/// Value, each of which correspond to the two product terms of read and write
/// enable respectively.
Value checkComplement(const SmallVector<Value> &readTerms,
const SmallVector<Value> &writeTerms) {
// Foreach Value in first term, check if it is the complement of any of the
// Value in second term.
for (auto t1 : prodTerms[0])
for (auto t2 : prodTerms[1]) {
// Return true if t1 is a Not of t2.
for (auto t1 : readTerms)
for (auto t2 : writeTerms) {
// Return t2, t1 is a Not of t2.
if (!t1.isa<BlockArgument>() && isa<NotPrimOp>(t1.getDefiningOp()))
if (cast<NotPrimOp>(t1.getDefiningOp()).getInput() == t2)
return true;
// Else Return true if t2 is a Not of t1.
return t2;
// Else Return t2, if t2 is a Not of t1.
if (!t2.isa<BlockArgument>() && isa<NotPrimOp>(t2.getDefiningOp()))
if (cast<NotPrimOp>(t2.getDefiningOp()).getInput() == t1)
return true;
return t2;
}
return isComplement;
return {};
}
void inferUnmasked(MemOp &memOp, SmallVector<Operation *> &opsToErase) {

View File

@ -40,13 +40,13 @@ firrtl.circuit "TLRAM" {
// CHECK: firrtl.strictconnect %[[v0:.+]], %[[v7]] : !firrtl.uint<4>
// CHECK: %[[v8:.+]] = firrtl.or %[[readEnable:.+]], %[[writeEnable]] : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: firrtl.strictconnect %[[v1:.+]], %[[v8]] : !firrtl.uint<1>
// CHECK: firrtl.strictconnect %[[v4:.+]], %[[writeEnable]]
// CHECK: firrtl.connect %[[readAddr]], %[[index2:.+]] : !firrtl.uint<4>, !firrtl.uint<4>
// CHECK: firrtl.connect %[[readEnable]], %mem_MPORT_en : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.connect %[[writeAddr]], %index : !firrtl.uint<4>, !firrtl.uint<4>
// CHECK: firrtl.connect %[[writeEnable]], %wen : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: %[[v10:.+]] = firrtl.not %wen : (!firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: firrtl.connect %mem_MPORT_en, %[[v10]] : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.strictconnect %[[v4:.+]], %wen : !firrtl.uint<1>
}
// Test the pattern of enable with Mux (sel, high, 0)
@ -76,6 +76,7 @@ firrtl.circuit "TLRAM" {
firrtl.connect %5, %io_addr : !firrtl.uint<11>, !firrtl.uint<11>
firrtl.connect %7, %clock : !firrtl.clock, !firrtl.clock
firrtl.connect %io_dataOut, %8 : !firrtl.uint<32>, !firrtl.uint<32>
// CHECK: firrtl.strictconnect %4, %io_wen : !firrtl.uint<1>
}
// Test the pattern of enable with an And tree and Mux (sel, high, 0)
@ -111,6 +112,7 @@ firrtl.circuit "TLRAM" {
firrtl.connect %5, %io_addr : !firrtl.uint<11>, !firrtl.uint<11>
firrtl.connect %7, %clock : !firrtl.clock, !firrtl.clock
firrtl.connect %io_dataOut, %8 : !firrtl.uint<32>, !firrtl.uint<32>
// CHECK: firrtl.strictconnect %4, %io_write : !firrtl.uint<1>
}
// Cannot merge read and write, since the pattern is enable = Mux (sel, high, 1)
@ -205,6 +207,7 @@ firrtl.circuit "TLRAM" {
firrtl.connect %6, %clock : !firrtl.clock, !firrtl.clock
firrtl.connect %8, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %7, %io_wdata : !firrtl.uint<32>, !firrtl.uint<32>
// CHECK: firrtl.strictconnect %4, %io_wen : !firrtl.uint<1>
}
// Check for indirect connection to clock