From b31ce1e65ab0e9ac68ce9ada85962c897d5b7e2f Mon Sep 17 00:00:00 2001 From: Schuyler Eldridge Date: Mon, 6 Dec 2021 16:13:27 -0500 Subject: [PATCH] [FIRRTL] Add RemoveResets Pass (#2287) Add a new pass, RemoveResets, that replaces RegResetOps that have invalidated initialization values with RegOps. This is part of a series of patches that are intended to align CIRCT with the Scala FIRRTL Compiler (SFC) interpretation of invalid. Previously, CIRCT relies on canonicalization/folding of invalid values to do this optimization. This pass enables future canonicalization/folding of invalid values to zero (as the SFC does) without having to worry about performing this optimization. Run the RemoveResets pass as part of firtool after ExpandWhens and before the first canonicalization. This enables conversion of invalidated RegResetOps to RegOps before canonicalization (eventually) interprets invalid values as zero. Signed-off-by: Schuyler Eldridge --- include/circt/Dialect/FIRRTL/Passes.h | 2 + include/circt/Dialect/FIRRTL/Passes.td | 5 + lib/Dialect/FIRRTL/Transforms/CMakeLists.txt | 5 +- .../FIRRTL/Transforms/RemoveResets.cpp | 124 ++++++++++++++++++ test/Dialect/FIRRTL/remove-resets.mlir | 84 ++++++++++++ tools/firtool/firtool.cpp | 1 + 6 files changed, 219 insertions(+), 2 deletions(-) create mode 100644 lib/Dialect/FIRRTL/Transforms/RemoveResets.cpp create mode 100644 test/Dialect/FIRRTL/remove-resets.mlir diff --git a/include/circt/Dialect/FIRRTL/Passes.h b/include/circt/Dialect/FIRRTL/Passes.h index da0afc3f0e..c1b0e819cb 100644 --- a/include/circt/Dialect/FIRRTL/Passes.h +++ b/include/circt/Dialect/FIRRTL/Passes.h @@ -69,6 +69,8 @@ std::unique_ptr createGrandCentralSignalMappingsPass(); std::unique_ptr createCheckCombCyclesPass(); +std::unique_ptr createRemoveResetsPass(); + /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "circt/Dialect/FIRRTL/Passes.h.inc" diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index 33982808de..eff91a12dc 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -331,4 +331,9 @@ def CheckCombCycles : Pass<"firrtl-check-comb-cycles", "firrtl::CircuitOp"> { let constructor = "circt::firrtl::createCheckCombCyclesPass()"; } +def RemoveResets : Pass<"firrtl-remove-resets", "firrtl::FModuleOp"> { + let summary = "Remove module-scoped invalidated resets"; + let constructor = "circt::firrtl::createRemoveResetsPass()"; +} + #endif // CIRCT_DIALECT_FIRRTL_PASSES_TD diff --git a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt index faa73d0863..9b3b1a870e 100755 --- a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt +++ b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt @@ -17,10 +17,11 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms ModuleInliner.cpp PrefixModules.cpp PrintInstanceGraph.cpp - + RemoveResets.cpp + DEPENDS CIRCTFIRRTLTransformsIncGen - + LINK_LIBS PUBLIC CIRCTFIRRTL CIRCTHW diff --git a/lib/Dialect/FIRRTL/Transforms/RemoveResets.cpp b/lib/Dialect/FIRRTL/Transforms/RemoveResets.cpp new file mode 100644 index 0000000000..44ef8eda8a --- /dev/null +++ b/lib/Dialect/FIRRTL/Transforms/RemoveResets.cpp @@ -0,0 +1,124 @@ +//===- RemoveResets.cpp - Remove resets of invalid value --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===// +// +// This pass converts registers that are reset to an invalid value to resetless +// registers. This is a reduced implementation of the Scala FIRRTL Compiler's +// RemoveResets pass. +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "circt/Dialect/FIRRTL/Passes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "firrtl-remove-resets" + +using namespace circt; +using namespace firrtl; + +struct RemoveResetsPass : public RemoveResetsBase { + void runOnOperation() override; +}; + +// Returns true if this value is invalidated. This requires that a value is +// only ever driven once. This is guaranteed if this runs after the +// `ExpandWhens` pass . +static bool isInvalid(Value val) { + + // Update `val` to the source of the connection driving `thisVal`. This walks + // backwards across users to find the first connection and updates `val` to + // the source. This assumes that only one connect is driving `thisVal`, i.e., + // this pass runs after `ExpandWhens`. + auto updateVal = [&](Value thisVal) { + for (auto *user : thisVal.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.dest() != val) + continue; + val = connect.src(); + return; + } + } + val = nullptr; + return; + }; + + while (val) { + // The value is a port. + if (auto blockArg = val.dyn_cast()) { + FModuleOp op = cast(val.getParentBlock()->getParentOp()); + auto direction = op.getPortDirection(blockArg.getArgNumber()); + // Base case: this is an input port and cannot be invalidated in module + // scope. + if (direction == Direction::In) + return false; + updateVal(blockArg); + continue; + } + + auto *op = val.getDefiningOp(); + + // The value is an instance port. + if (auto inst = dyn_cast(op)) { + auto resultNo = val.cast().getResultNumber(); + // An output port of an instance crosses a module boundary. This is not + // invalid within module scope. + if (inst.getPortDirection(resultNo) == Direction::Out) + return false; + updateVal(val); + continue; + } + + // Base case: we found an invalid value. We're done, return true. + if (isa(op)) + return true; + + // Base case: we hit something that is NOT a wire, e.g., a PrimOp. We're + // done, return false. + if (!isa(op)) + return false; + + // Update `val` with the driver of the wire. If no driver found, `val` will + // be set to nullptr and we exit on the next while iteration. + updateVal(op->getResult(0)); + }; + return false; +}; + +void RemoveResetsPass::runOnOperation() { + LLVM_DEBUG( + llvm::dbgs() << "===----- Running RemoveResets " + "-----------------------------------------------===\n" + << "Module: '" << getOperation().getName() << "'\n";); + + bool madeModifications = false; + for (auto reg : llvm::make_early_inc_range( + getOperation().getBody()->getOps())) { + + // If the `RegResetOp` has an invalidated initialization, then replace it + // with a `RegOp`. + if (isInvalid(reg.resetValue())) { + LLVM_DEBUG(llvm::dbgs() << " - RegResetOp '" << reg.name() + << "' will be replaced with a RegOp\n"); + ImplicitLocOpBuilder builder(reg.getLoc(), reg); + RegOp newReg = + builder.create(reg.getType(), reg.clockVal(), reg.name(), + reg.annotations(), reg.inner_symAttr()); + reg.replaceAllUsesWith(newReg.getResult()); + reg.erase(); + madeModifications = true; + } + } + + if (!madeModifications) + return markAllAnalysesPreserved(); +} + +std::unique_ptr circt::firrtl::createRemoveResetsPass() { + return std::make_unique(); +} diff --git a/test/Dialect/FIRRTL/remove-resets.mlir b/test/Dialect/FIRRTL/remove-resets.mlir new file mode 100644 index 0000000000..5f113413a9 --- /dev/null +++ b/test/Dialect/FIRRTL/remove-resets.mlir @@ -0,0 +1,84 @@ +// RUN: circt-opt --pass-pipeline='firrtl.circuit(firrtl.module(firrtl-remove-resets))' --verify-diagnostics --split-input-file %s | FileCheck %s + +firrtl.circuit "RemoveResetTests" { + + firrtl.module @RemoveResetTests() {} + + // An invalidated regreset should be converted to a reg. + // + // CHECK-LABEL: @InvalidValue + firrtl.module @InvalidValue(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) { + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + // CHECK: firrtl.reg %clock + %r = firrtl.regreset %clock, %reset, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1> + } + + // A regreset invalidated through a wire should be converted to a reg. + // + // CHECK-LABEL: @InvalidThroughWire + firrtl.module @InvalidThroughWire(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) { + %inv = firrtl.wire : !firrtl.uint<1> + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.reg %clock + %r = firrtl.regreset %clock, %reset, %inv : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1> + } + + // A regreset invalidated via an output port should be converted to a reg. + // + // CHECK-LABEL: @InvalidPort + firrtl.module @InvalidPort(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>, out %x: !firrtl.uint<1>) { + %inv = firrtl.wire : !firrtl.uint<1> + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %x, %inv : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.reg %clock + %r = firrtl.regreset %clock, %reset, %x : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1> + } + + // A regreset invalidate via an instance input port should be converted to a + // reg. + // + // CHECK-LABEL: @InvalidInstancePort + firrtl.module @InvalidInstancePort_Submodule(in %inv: !firrtl.uint<1>) {} + firrtl.module @InvalidInstancePort(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) { + %inv = firrtl.wire : !firrtl.uint<1> + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + firrtl.connect %inv, %invalid_ui1 : !firrtl.uint<1>, !firrtl.uint<1> + %submodule_inv = firrtl.instance submodule @InvalidInstancePort_Submodule(in inv: !firrtl.uint<1>) + firrtl.connect %submodule_inv, %inv : !firrtl.uint<1>, !firrtl.uint<1> + // CHECK: firrtl.reg %clock + %r = firrtl.regreset %clock, %reset, %submodule_inv : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1> + } + + // A primitive operation should block invalid propagation. + firrtl.module @InvalidPrimop(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<1>, out %q: !firrtl.uint<1>) { + %invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1> + %0 = firrtl.not %invalid_ui1 : (!firrtl.uint<1>) -> !firrtl.uint<1> + // CHECK: firrtl.regreset %clock + %r = firrtl.regreset %clock, %reset, %0 : !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %r, %d : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1> + } + + // A regreset invalid value should NOT propagate through a node. + firrtl.module @Foo(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<8>, out %q: !firrtl.uint<8>) { + %inv = firrtl.wire : !firrtl.uint<8> + %invalid_ui8 = firrtl.invalidvalue : !firrtl.uint<8> + firrtl.connect %inv, %invalid_ui8 : !firrtl.uint<8>, !firrtl.uint<8> + %_T = firrtl.node %inv : !firrtl.uint<8> + // CHECK: firrtl.regreset %clock + %r = firrtl.regreset %clock, %reset, %_T : !firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8> + firrtl.connect %r, %d : !firrtl.uint<8>, !firrtl.uint<8> + firrtl.connect %q, %r : !firrtl.uint<8>, !firrtl.uint<8> + } + +} diff --git a/tools/firtool/firtool.cpp b/tools/firtool/firtool.cpp index 55d63e5a89..7e1a4aa739 100644 --- a/tools/firtool/firtool.cpp +++ b/tools/firtool/firtool.cpp @@ -364,6 +364,7 @@ processBuffer(MLIRContext &context, TimingScope &ts, llvm::SourceMgr &sourceMgr, if (expandWhens) { auto &modulePM = pm.nest().nest(); modulePM.addPass(firrtl::createExpandWhensPass()); + modulePM.addPass(firrtl::createRemoveResetsPass()); } }