mirror of https://github.com/llvm/circt.git
[FIRRTL] Add RemoveResets Pass (#2287)
Add a new pass, RemoveResets, that replaces RegResetOps that have invalidated initialization values with RegOps. This is part of a series of patches that are intended to align CIRCT with the Scala FIRRTL Compiler (SFC) interpretation of invalid. Previously, CIRCT relies on canonicalization/folding of invalid values to do this optimization. This pass enables future canonicalization/folding of invalid values to zero (as the SFC does) without having to worry about performing this optimization. Run the RemoveResets pass as part of firtool after ExpandWhens and before the first canonicalization. This enables conversion of invalidated RegResetOps to RegOps before canonicalization (eventually) interprets invalid values as zero. Signed-off-by: Schuyler Eldridge <schuyler.eldridge@sifive.com>
This commit is contained in:
parent
635786e1cb
commit
b31ce1e65a
|
@ -69,6 +69,8 @@ std::unique_ptr<mlir::Pass> createGrandCentralSignalMappingsPass();
|
|||
|
||||
std::unique_ptr<mlir::Pass> createCheckCombCyclesPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createRemoveResetsPass();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "circt/Dialect/FIRRTL/Passes.h.inc"
|
||||
|
|
|
@ -331,4 +331,9 @@ def CheckCombCycles : Pass<"firrtl-check-comb-cycles", "firrtl::CircuitOp"> {
|
|||
let constructor = "circt::firrtl::createCheckCombCyclesPass()";
|
||||
}
|
||||
|
||||
def RemoveResets : Pass<"firrtl-remove-resets", "firrtl::FModuleOp"> {
|
||||
let summary = "Remove module-scoped invalidated resets";
|
||||
let constructor = "circt::firrtl::createRemoveResetsPass()";
|
||||
}
|
||||
|
||||
#endif // CIRCT_DIALECT_FIRRTL_PASSES_TD
|
||||
|
|
|
@ -17,10 +17,11 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms
|
|||
ModuleInliner.cpp
|
||||
PrefixModules.cpp
|
||||
PrintInstanceGraph.cpp
|
||||
|
||||
RemoveResets.cpp
|
||||
|
||||
DEPENDS
|
||||
CIRCTFIRRTLTransformsIncGen
|
||||
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
CIRCTFIRRTL
|
||||
CIRCTHW
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
//===- RemoveResets.cpp - Remove resets of invalid value --------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This pass converts registers that are reset to an invalid value to resetless
|
||||
// registers. This is a reduced implementation of the Scala FIRRTL Compiler's
|
||||
// RemoveResets pass.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetails.h"
|
||||
#include "circt/Dialect/FIRRTL/Passes.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "firrtl-remove-resets"
|
||||
|
||||
using namespace circt;
|
||||
using namespace firrtl;
|
||||
|
||||
struct RemoveResetsPass : public RemoveResetsBase<RemoveResetsPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Returns true if this value is invalidated. This requires that a value is
|
||||
// only ever driven once. This is guaranteed if this runs after the
|
||||
// `ExpandWhens` pass .
|
||||
static bool isInvalid(Value val) {
|
||||
|
||||
// Update `val` to the source of the connection driving `thisVal`. This walks
|
||||
// backwards across users to find the first connection and updates `val` to
|
||||
// the source. This assumes that only one connect is driving `thisVal`, i.e.,
|
||||
// this pass runs after `ExpandWhens`.
|
||||
auto updateVal = [&](Value thisVal) {
|
||||
for (auto *user : thisVal.getUsers()) {
|
||||
if (auto connect = dyn_cast<ConnectOp>(user)) {
|
||||
if (connect.dest() != val)
|
||||
continue;
|
||||
val = connect.src();
|
||||
return;
|
||||
}
|
||||
}
|
||||
val = nullptr;
|
||||
return;
|
||||
};
|
||||
|
||||
while (val) {
|
||||
// The value is a port.
|
||||
if (auto blockArg = val.dyn_cast<BlockArgument>()) {
|
||||
FModuleOp op = cast<FModuleOp>(val.getParentBlock()->getParentOp());
|
||||
auto direction = op.getPortDirection(blockArg.getArgNumber());
|
||||
// Base case: this is an input port and cannot be invalidated in module
|
||||
// scope.
|
||||
if (direction == Direction::In)
|
||||
return false;
|
||||
updateVal(blockArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *op = val.getDefiningOp();
|
||||
|
||||
// The value is an instance port.
|
||||
if (auto inst = dyn_cast<InstanceOp>(op)) {
|
||||
auto resultNo = val.cast<OpResult>().getResultNumber();
|
||||
// An output port of an instance crosses a module boundary. This is not
|
||||
// invalid within module scope.
|
||||
if (inst.getPortDirection(resultNo) == Direction::Out)
|
||||
return false;
|
||||
updateVal(val);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Base case: we found an invalid value. We're done, return true.
|
||||
if (isa<InvalidValueOp>(op))
|
||||
return true;
|
||||
|
||||
// Base case: we hit something that is NOT a wire, e.g., a PrimOp. We're
|
||||
// done, return false.
|
||||
if (!isa<WireOp>(op))
|
||||
return false;
|
||||
|
||||
// Update `val` with the driver of the wire. If no driver found, `val` will
|
||||
// be set to nullptr and we exit on the next while iteration.
|
||||
updateVal(op->getResult(0));
|
||||
};
|
||||
return false;
|
||||
};
|
||||
|
||||
void RemoveResetsPass::runOnOperation() {
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs() << "===----- Running RemoveResets "
|
||||
"-----------------------------------------------===\n"
|
||||
<< "Module: '" << getOperation().getName() << "'\n";);
|
||||
|
||||
bool madeModifications = false;
|
||||
for (auto reg : llvm::make_early_inc_range(
|
||||
getOperation().getBody()->getOps<RegResetOp>())) {
|
||||
|
||||
// If the `RegResetOp` has an invalidated initialization, then replace it
|
||||
// with a `RegOp`.
|
||||
if (isInvalid(reg.resetValue())) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " - RegResetOp '" << reg.name()
|
||||
<< "' will be replaced with a RegOp\n");
|
||||
ImplicitLocOpBuilder builder(reg.getLoc(), reg);
|
||||
RegOp newReg =
|
||||
builder.create<RegOp>(reg.getType(), reg.clockVal(), reg.name(),
|
||||
reg.annotations(), reg.inner_symAttr());
|
||||
reg.replaceAllUsesWith(newReg.getResult());
|
||||
reg.erase();
|
||||
madeModifications = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!madeModifications)
|
||||
return markAllAnalysesPreserved();
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::Pass> circt::firrtl::createRemoveResetsPass() {
|
||||
return std::make_unique<RemoveResetsPass>();
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
// RUN: circt-opt --pass-pipeline='firrtl.circuit(firrtl.module(firrtl-remove-resets))' --verify-diagnostics --split-input-file %s | FileCheck %s
|
||||
|
||||
firrtl.circuit "RemoveResetTests" {
|
||||
|
||||
firrtl.module @RemoveResetTests() {}
|
||||
|
||||
// An invalidated regreset should be converted to a reg.
|
||||
//
|
||||
// CHECK-LABEL: @InvalidValue
|
||||
firrtl.module @InvalidValue(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
|
||||
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
|
||||
// CHECK: firrtl.reg %clock
|
||||
%r = firrtl.regreset %clock, %reset, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
}
|
||||
|
||||
// A regreset invalidated through a wire should be converted to a reg.
|
||||
//
|
||||
// CHECK-LABEL: @InvalidThroughWire
|
||||
firrtl.module @InvalidThroughWire(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
|
||||
%inv = firrtl.wire : !firrtl.uint<1>
|
||||
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
|
||||
firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
// CHECK: firrtl.reg %clock
|
||||
%r = firrtl.regreset %clock, %reset, %inv : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
}
|
||||
|
||||
// A regreset invalidated via an output port should be converted to a reg.
|
||||
//
|
||||
// CHECK-LABEL: @InvalidPort
|
||||
firrtl.module @InvalidPort(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>, out %x: !firrtl.uint<1>) {
|
||||
%inv = firrtl.wire : !firrtl.uint<1>
|
||||
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
|
||||
firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %x, %inv : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
// CHECK: firrtl.reg %clock
|
||||
%r = firrtl.regreset %clock, %reset, %x : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
}
|
||||
|
||||
// A regreset invalidate via an instance input port should be converted to a
|
||||
// reg.
|
||||
//
|
||||
// CHECK-LABEL: @InvalidInstancePort
|
||||
firrtl.module @InvalidInstancePort_Submodule(in %inv: !firrtl.uint<1>) {}
|
||||
firrtl.module @InvalidInstancePort(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
|
||||
%inv = firrtl.wire : !firrtl.uint<1>
|
||||
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
|
||||
firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
%submodule_inv = firrtl.instance submodule @InvalidInstancePort_Submodule(in inv: !firrtl.uint<1>)
|
||||
firrtl.connect %submodule_inv, %inv : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
// CHECK: firrtl.reg %clock
|
||||
%r = firrtl.regreset %clock, %reset, %submodule_inv : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
}
|
||||
|
||||
// A primitive operation should block invalid propagation.
|
||||
firrtl.module @InvalidPrimop(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) {
|
||||
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
|
||||
%0 = firrtl.not %invalid_ui1 : (!firrtl.uint<1>) -> !firrtl.uint<1>
|
||||
// CHECK: firrtl.regreset %clock
|
||||
%r = firrtl.regreset %clock, %reset, %0 : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
}
|
||||
|
||||
// A regreset invalid value should NOT propagate through a node.
|
||||
firrtl.module @Foo(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<8>, out %q: !firrtl.uint<8>) {
|
||||
%inv = firrtl.wire : !firrtl.uint<8>
|
||||
%invalid_ui8 = firrtl.invalidvalue : !firrtl.uint<8>
|
||||
firrtl.connect %inv, %invalid_ui8 : !firrtl.uint<8>, !firrtl.uint<8>
|
||||
%_T = firrtl.node %inv : !firrtl.uint<8>
|
||||
// CHECK: firrtl.regreset %clock
|
||||
%r = firrtl.regreset %clock, %reset, %_T : !firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>
|
||||
firrtl.connect %r, %d : !firrtl.uint<8>, !firrtl.uint<8>
|
||||
firrtl.connect %q, %r : !firrtl.uint<8>, !firrtl.uint<8>
|
||||
}
|
||||
|
||||
}
|
|
@ -364,6 +364,7 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr,
|
|||
if (expandWhens) {
|
||||
auto &modulePM = pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>();
|
||||
modulePM.addPass(firrtl::createExpandWhensPass());
|
||||
modulePM.addPass(firrtl::createRemoveResetsPass());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue