[FIRRTL] Add RemoveResets Pass (#2287)

Add a new pass, RemoveResets, that replaces RegResetOps that have
invalidated initialization values with RegOps.  This is part of a series
of patches that are intended to align CIRCT with the Scala FIRRTL
Compiler (SFC) interpretation of invalid.  Previously, CIRCT relies on
canonicalization/folding of invalid values to do this optimization.
This pass enables future canonicalization/folding of invalid values to
zero (as the SFC does) without having to worry about performing this
optimization.

Run the RemoveResets pass as part of firtool after ExpandWhens and
before the first canonicalization.  This enables conversion of
invalidated RegResetOps to RegOps before canonicalization (eventually)
interprets invalid values as zero.

Signed-off-by: Schuyler Eldridge <schuyler.eldridge@sifive.com>
This commit is contained in:
Schuyler Eldridge 2021-12-06 16:13:27 -05:00 committed by GitHub
parent 635786e1cb
commit b31ce1e65a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 219 additions and 2 deletions

View File

@ -69,6 +69,8 @@ std::unique_ptr<mlir::Pass> createGrandCentralSignalMappingsPass();
std::unique_ptr<mlir::Pass> createCheckCombCyclesPass();
std::unique_ptr<mlir::Pass> createRemoveResetsPass();
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "circt/Dialect/FIRRTL/Passes.h.inc"

View File

@ -331,4 +331,9 @@ def CheckCombCycles : Pass<"firrtl-check-comb-cycles", "firrtl::CircuitOp"> {
let constructor = "circt::firrtl::createCheckCombCyclesPass()";
}
def RemoveResets : Pass<"firrtl-remove-resets", "firrtl::FModuleOp"> {
let summary = "Remove module-scoped invalidated resets";
let constructor = "circt::firrtl::createRemoveResetsPass()";
}
#endif // CIRCT_DIALECT_FIRRTL_PASSES_TD

View File

@ -17,6 +17,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms
ModuleInliner.cpp
PrefixModules.cpp
PrintInstanceGraph.cpp
RemoveResets.cpp
DEPENDS
CIRCTFIRRTLTransformsIncGen

View File

@ -0,0 +1,124 @@
//===- RemoveResets.cpp - Remove resets of invalid value --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
//
// This pass converts registers that are reset to an invalid value to resetless
// registers. This is a reduced implementation of the Scala FIRRTL Compiler's
// RemoveResets pass.
//
//===----------------------------------------------------------------------===//
#include "PassDetails.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "firrtl-remove-resets"
using namespace circt;
using namespace firrtl;
struct RemoveResetsPass : public RemoveResetsBase<RemoveResetsPass> {
void runOnOperation() override;
};
// Returns true if this value is invalidated. This requires that a value is
// only ever driven once. This is guaranteed if this runs after the
// `ExpandWhens` pass .
static bool isInvalid(Value val) {
// Update `val` to the source of the connection driving `thisVal`. This walks
// backwards across users to find the first connection and updates `val` to
// the source. This assumes that only one connect is driving `thisVal`, i.e.,
// this pass runs after `ExpandWhens`.
auto updateVal = [&](Value thisVal) {
for (auto *user : thisVal.getUsers()) {
if (auto connect = dyn_cast<ConnectOp>(user)) {
if (connect.dest() != val)
continue;
val = connect.src();
return;
}
}
val = nullptr;
return;
};
while (val) {
// The value is a port.
if (auto blockArg = val.dyn_cast<BlockArgument>()) {
FModuleOp op = cast<FModuleOp>(val.getParentBlock()->getParentOp());
auto direction = op.getPortDirection(blockArg.getArgNumber());
// Base case: this is an input port and cannot be invalidated in module
// scope.
if (direction == Direction::In)
return false;
updateVal(blockArg);
continue;
}
auto *op = val.getDefiningOp();
// The value is an instance port.
if (auto inst = dyn_cast<InstanceOp>(op)) {
auto resultNo = val.cast<OpResult>().getResultNumber();
// An output port of an instance crosses a module boundary. This is not
// invalid within module scope.
if (inst.getPortDirection(resultNo) == Direction::Out)
return false;
updateVal(val);
continue;
}
// Base case: we found an invalid value. We're done, return true.
if (isa<InvalidValueOp>(op))
return true;
// Base case: we hit something that is NOT a wire, e.g., a PrimOp. We're
// done, return false.
if (!isa<WireOp>(op))
return false;
// Update `val` with the driver of the wire. If no driver found, `val` will
// be set to nullptr and we exit on the next while iteration.
updateVal(op->getResult(0));
};
return false;
};
void RemoveResetsPass::runOnOperation() {
LLVM_DEBUG(
llvm::dbgs() << "===----- Running RemoveResets "
"-----------------------------------------------===\n"
<< "Module: '" << getOperation().getName() << "'\n";);
bool madeModifications = false;
for (auto reg : llvm::make_early_inc_range(
getOperation().getBody()->getOps<RegResetOp>())) {
// If the `RegResetOp` has an invalidated initialization, then replace it
// with a `RegOp`.
if (isInvalid(reg.resetValue())) {
LLVM_DEBUG(llvm::dbgs() << " - RegResetOp '" << reg.name()
<< "' will be replaced with a RegOp\n");
ImplicitLocOpBuilder builder(reg.getLoc(), reg);
RegOp newReg =
builder.create<RegOp>(reg.getType(), reg.clockVal(), reg.name(),
reg.annotations(), reg.inner_symAttr());
reg.replaceAllUsesWith(newReg.getResult());
reg.erase();
madeModifications = true;
}
}
if (!madeModifications)
return markAllAnalysesPreserved();
}
std::unique_ptr<mlir::Pass> circt::firrtl::createRemoveResetsPass() {
return std::make_unique<RemoveResetsPass>();
}

View File

@ -0,0 +1,84 @@
// RUN: circt-opt --pass-pipeline='firrtl.circuit(firrtl.module(firrtl-remove-resets))' --verify-diagnostics --split-input-file %s | FileCheck %s
firrtl.circuit "RemoveResetTests" {
firrtl.module @RemoveResetTests() {}
// An invalidated regreset should be converted to a reg.
//
// CHECK-LABEL: @InvalidValue
firrtl.module @InvalidValue(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
// CHECK: firrtl.reg %clock
%r = firrtl.regreset %clock, %reset, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
}
// A regreset invalidated through a wire should be converted to a reg.
//
// CHECK-LABEL: @InvalidThroughWire
firrtl.module @InvalidThroughWire(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
%inv = firrtl.wire : !firrtl.uint<1>
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.reg %clock
%r = firrtl.regreset %clock, %reset, %inv : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
}
// A regreset invalidated via an output port should be converted to a reg.
//
// CHECK-LABEL: @InvalidPort
firrtl.module @InvalidPort(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>, out %x: !firrtl.uint<1>) {
%inv = firrtl.wire : !firrtl.uint<1>
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %x, %inv : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.reg %clock
%r = firrtl.regreset %clock, %reset, %x : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
}
// A regreset invalidate via an instance input port should be converted to a
// reg.
//
// CHECK-LABEL: @InvalidInstancePort
firrtl.module @InvalidInstancePort_Submodule(in %inv: !firrtl.uint<1>) {}
firrtl.module @InvalidInstancePort(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
%inv = firrtl.wire : !firrtl.uint<1>
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
%submodule_inv = firrtl.instance submodule @InvalidInstancePort_Submodule(in inv: !firrtl.uint<1>)
firrtl.connect %submodule_inv, %inv : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.reg %clock
%r = firrtl.regreset %clock, %reset, %submodule_inv : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
}
// A primitive operation should block invalid propagation.
firrtl.module @InvalidPrimop(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
%0 = firrtl.not %invalid_ui1 : (!firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: firrtl.regreset %clock
%r = firrtl.regreset %clock, %reset, %0 : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
}
// A regreset invalid value should NOT propagate through a node.
firrtl.module @Foo(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<8>, out %q: !firrtl.uint<8>) {
%inv = firrtl.wire : !firrtl.uint<8>
%invalid_ui8 = firrtl.invalidvalue : !firrtl.uint<8>
firrtl.connect %inv, %invalid_ui8 : !firrtl.uint<8>, !firrtl.uint<8>
%_T = firrtl.node %inv : !firrtl.uint<8>
// CHECK: firrtl.regreset %clock
%r = firrtl.regreset %clock, %reset, %_T : !firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>
firrtl.connect %r, %d : !firrtl.uint<8>, !firrtl.uint<8>
firrtl.connect %q, %r : !firrtl.uint<8>, !firrtl.uint<8>
}
}

View File

@ -364,6 +364,7 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr,
if (expandWhens) {
auto &modulePM = pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>();
modulePM.addPass(firrtl::createExpandWhensPass());
modulePM.addPass(firrtl::createRemoveResetsPass());
}
}