From 56a260a1d465a55ba7003ff6a914f7b129a1e01e Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Sat, 16 Jul 2022 21:57:24 +0200 Subject: [PATCH] [FSMToSV] Add FSM to SV conversion pass (#3483) This commit introduces an FSM to SV lowering pass, as well as some small modifications to the FSM dialect to facilitate the conversion. This initial version of the pass does not support transition action regions and variables. The lowering style is fairly straight forward; two processes are emitted, one `always_ff` for state register inference, one `always_comb` for next-state calculation and output assignments. e.g.: ```mlir fsm.machine @top(%a0: i1, %arg1: i1) -> (i8, i8) attributes {initialState = "A", argNames = ["a0", "a1"], resNames = ["r0", "r1"]} { %c_42 = hw.constant 42 : i8 fsm.state @A output { %c_0 = hw.constant 0 : i8 fsm.output %c_0, %c_42 : i8, i8 } transitions { fsm.transition @B } fsm.state @B output { %c_1 = hw.constant 1 : i8 fsm.output %c_1, %c_42 : i8, i8 } transitions { fsm.transition @A guard { %g = comb.and %a0, %arg1 : i1 fsm.return %g } } } ``` emits as ```sv typedef enum {A, B} top_state_t; module top( input a0, a1, clk, rst, output [7:0] r0, r1); reg [7:0] output_1; reg [7:0] output_0; top_state_t next_state; wire top_state_t to_A; wire top_state_t to_B; top_state_t state_reg; assign to_A = A; assign to_B = B; always_ff @(posedge clk) begin if (rst) state_reg <= to_A; else state_reg <= next_state; end always_comb begin case (state_reg) A: begin next_state = to_B; output_0 = 8'h0; output_1 = 8'h2A; end B: begin next_state = a0 & a1 ? to_A : to_B; output_0 = 8'h1; output_1 = 8'h2A; end endcase end assign r0 = output_0; assign r1 = output_1; endmodule ``` --- docs/Dialects/FSM/RationaleFSM.md | 2 +- include/circt/Conversion/FSMToSV.h | 22 + include/circt/Conversion/Passes.h | 1 + include/circt/Conversion/Passes.td | 11 + include/circt/Dialect/FSM/FSMOps.td | 20 +- integration_test/Dialect/FSM/driver.cpp | 53 ++ integration_test/Dialect/FSM/lit.local.cfg.py | 1 + integration_test/Dialect/FSM/top.mlir | 41 ++ lib/Conversion/CMakeLists.txt | 1 + lib/Conversion/FSMToSV/CMakeLists.txt | 18 + lib/Conversion/FSMToSV/FSMToSV.cpp | 565 ++++++++++++++++++ lib/Conversion/PassDetail.h | 4 + lib/Dialect/FSM/FSMOps.cpp | 27 +- lib/Dialect/SV/SVOps.cpp | 2 +- test/Conversion/FSMToSV/test_basic.mlir | 91 +++ test/Conversion/FSMToSV/test_errors.mlir | 32 + test/Dialect/FSM/basics.mlir | 4 +- test/Dialect/FSM/errors.mlir | 15 + tools/circt-opt/CMakeLists.txt | 1 + 19 files changed, 898 insertions(+), 13 deletions(-) create mode 100644 include/circt/Conversion/FSMToSV.h create mode 100644 integration_test/Dialect/FSM/driver.cpp create mode 100644 integration_test/Dialect/FSM/lit.local.cfg.py create mode 100644 integration_test/Dialect/FSM/top.mlir create mode 100644 lib/Conversion/FSMToSV/CMakeLists.txt create mode 100644 lib/Conversion/FSMToSV/FSMToSV.cpp create mode 100644 test/Conversion/FSMToSV/test_basic.mlir create mode 100644 test/Conversion/FSMToSV/test_errors.mlir diff --git a/docs/Dialects/FSM/RationaleFSM.md b/docs/Dialects/FSM/RationaleFSM.md index d3facdba31..2e42730814 100644 --- a/docs/Dialects/FSM/RationaleFSM.md +++ b/docs/Dialects/FSM/RationaleFSM.md @@ -24,7 +24,7 @@ with the following features: internal variables of an FSM, allowing convenient analysis and transformation. 2. Provide a target-agnostic representation of FSM, allowing the state machine to be instantiated and attached to other dialects from different domains. -3. By cooperating with two conversion passes, FSMToHW and FSMToStandard, allow +3. By cooperating with two conversion passes, FSMToSV and FSMToStandard, allow to lower the FSM abstraction into HW+Comb+SV (Hardware) and Standard+SCF+MemRef (Software) dialects for the purposes of simulation, code generation, etc. diff --git a/include/circt/Conversion/FSMToSV.h b/include/circt/Conversion/FSMToSV.h new file mode 100644 index 0000000000..087c262d6b --- /dev/null +++ b/include/circt/Conversion/FSMToSV.h @@ -0,0 +1,22 @@ +//===- FSMToSV.h - FSM to SV conversions ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_CONVERSION_FSMTOSV_FSMTOSV_H +#define CIRCT_CONVERSION_FSMTOSV_FSMTOSV_H + +#include + +namespace mlir { +class Pass; +} // namespace mlir + +namespace circt { +std::unique_ptr createConvertFSMToSVPass(); +} // namespace circt + +#endif // CIRCT_CONVERSION_FSMTOSV_FSMTOSV_H diff --git a/include/circt/Conversion/Passes.h b/include/circt/Conversion/Passes.h index 6cfa1399d7..4039381869 100644 --- a/include/circt/Conversion/Passes.h +++ b/include/circt/Conversion/Passes.h @@ -17,6 +17,7 @@ #include "circt/Conversion/CalyxToHW.h" #include "circt/Conversion/ExportVerilog.h" #include "circt/Conversion/FIRRTLToHW.h" +#include "circt/Conversion/FSMToSV.h" #include "circt/Conversion/HWToLLHD.h" #include "circt/Conversion/HandshakeToFIRRTL.h" #include "circt/Conversion/HandshakeToHW.h" diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index 5d1555c096..64b93561cb 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -147,6 +147,17 @@ def CalyxToHW : Pass<"lower-calyx-to-hw", "mlir::ModuleOp"> { "seq::SeqDialect", "sv::SVDialect"]; } +//===----------------------------------------------------------------------===// +// FSMToSV +//===----------------------------------------------------------------------===// + +def ConvertFSMToSV : Pass<"convert-fsm-to-sv", "mlir::ModuleOp"> { + let summary = "Convert FSM to HW"; + let constructor = "circt::createConvertFSMToSVPass()"; + let dependentDialects = ["circt::hw::HWDialect", "circt::comb::CombDialect", + "circt::seq::SeqDialect", "circt::sv::SVDialect"]; +} + //===----------------------------------------------------------------------===// // FIRRTLToHW //===----------------------------------------------------------------------===// diff --git a/include/circt/Dialect/FSM/FSMOps.td b/include/circt/Dialect/FSM/FSMOps.td index 510428fc19..18da782d16 100644 --- a/include/circt/Dialect/FSM/FSMOps.td +++ b/include/circt/Dialect/FSM/FSMOps.td @@ -25,7 +25,9 @@ def MachineOp : FSMOp<"machine", [FunctionOpInterface, }]; let arguments = (ins StrAttr:$sym_name, StrAttr:$initialState, - TypeAttrOf:$function_type); + TypeAttrOf:$function_type, + OptionalAttr:$argNames, + OptionalAttr:$resNames); let regions = (region SizedRegion<1>:$body); let builders = [ @@ -47,6 +49,12 @@ def MachineOp : FSMOp<"machine", [FunctionOpInterface, return getFunctionTypeAttr().getValue().cast(); } + /// Returns the number of states in this machhine. + size_t getNumStates() { + auto stateOps = getBody().getOps(); + return std::distance(stateOps.begin(), stateOps.end()); + } + /// Returns the argument types of this function. ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } @@ -112,7 +120,7 @@ def TriggerOp : FSMOp<"trigger", []> { let hasVerifier = 1; } -def HWInstanceOp : FSMOp<"hw_instance", [Symbol, AttrSizedOperandSegments]> { +def HWInstanceOp : FSMOp<"hw_instance", [Symbol]> { let summary = "Create a hardware-style instance of a state machine"; let description = [{ `fsm.hw_instance` represents a hardware-style instance of a state machine, @@ -122,14 +130,12 @@ def HWInstanceOp : FSMOp<"hw_instance", [Symbol, AttrSizedOperandSegments]> { }]; let arguments = (ins StrAttr:$sym_name, FlatSymbolRefAttr:$machine, - Variadic:$inputs, Optional:$clock, - Optional:$reset); + Variadic:$inputs, I1:$clock, I1:$reset); let results = (outs Variadic:$outputs); let assemblyFormat = [{ - $sym_name $machine attr-dict `(` $inputs `)` `:` - functional-type($inputs, $outputs) (`,` `clock` $clock^ `:` qualified(type($clock)))? - (`,` `reset` $reset^ `:` qualified(type($reset)))? + $sym_name $machine attr-dict `(` $inputs `)` + `,` `clock` $clock `,` `reset` $reset `:` functional-type($inputs, $outputs) }]; let extraClassDeclaration = [{ diff --git a/integration_test/Dialect/FSM/driver.cpp b/integration_test/Dialect/FSM/driver.cpp new file mode 100644 index 0000000000..04dac17066 --- /dev/null +++ b/integration_test/Dialect/FSM/driver.cpp @@ -0,0 +1,53 @@ +#include "Vtop.h" +#include "verilated.h" +#include + +int main(int argc, char **argv) { + + Verilated::commandArgs(argc, argv); + auto *tb = new Vtop; + + // Post-reset start time. + int t0 = 2; + + for (int i = 0; i < 10; i++) { + if (i > t0) + std::cout << "out: " << char('A' + tb->out0) << std::endl; + + // Rising edge + tb->clk = 1; + tb->eval(); + + // Testbench + tb->rst = i < t0; + + // t0: Starts in A, + // t0+1: Default transition to B + + if (i == t0 + 2) { + // B -> C + tb->in0 = 1; + tb->in1 = 1; + } + + if (i == t0 + 3) { + // C -> B + tb->in0 = 0; + tb->in1 = 0; + } + + if (i == t0 + 4 || i == t0 + 5) { + // B -> C, C-> A + tb->in0 = 1; + tb->in1 = 1; + } + + // t0+6: Default transition to B + + // Falling edge + tb->clk = 0; + tb->eval(); + } + + exit(EXIT_SUCCESS); +} diff --git a/integration_test/Dialect/FSM/lit.local.cfg.py b/integration_test/Dialect/FSM/lit.local.cfg.py new file mode 100644 index 0000000000..1a3867bc01 --- /dev/null +++ b/integration_test/Dialect/FSM/lit.local.cfg.py @@ -0,0 +1 @@ +config.excludes.add('driver.cpp') diff --git a/integration_test/Dialect/FSM/top.mlir b/integration_test/Dialect/FSM/top.mlir new file mode 100644 index 0000000000..e8ea90d466 --- /dev/null +++ b/integration_test/Dialect/FSM/top.mlir @@ -0,0 +1,41 @@ +// REQUIRES: verilator +// RUN: circt-opt %s --convert-fsm-to-sv --canonicalize --lower-seq-to-sv --export-verilog -o %t2.mlir > %t1.sv +// RUN: circt-rtl-sim.py %t1.sv %S/driver.cpp | FileCheck %s +// CHECK: out: A +// CHECK: out: B +// CHECK: out: B +// CHECK: out: C +// CHECK: out: B +// CHECK: out: C +// CHECK: out: A + +fsm.machine @top(%arg0: i1, %arg1: i1) -> (i8) attributes {initialState = "A"} { + + fsm.state @A output { + %c_0 = hw.constant 0 : i8 + fsm.output %c_0 : i8 + } transitions { + fsm.transition @B + } + + fsm.state @B output { + %c_1 = hw.constant 1 : i8 + fsm.output %c_1 : i8 + } transitions { + fsm.transition @C guard { + %g = comb.and %arg0, %arg1 : i1 + fsm.return %g + } + } + + fsm.state @C output { + %c_2 = hw.constant 2 : i8 + fsm.output %c_2 : i8 + } transitions { + fsm.transition @A guard { + %g = comb.and %arg0, %arg1 : i1 + fsm.return %g + } + fsm.transition @B + } +} diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index bd5c530a1c..d4f86d8d15 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -11,3 +11,4 @@ add_subdirectory(SCFToCalyx) add_subdirectory(StaticLogicToCalyx) add_subdirectory(StandardToHandshake) add_subdirectory(StandardToStaticLogic) +add_subdirectory(FSMToSV) diff --git a/lib/Conversion/FSMToSV/CMakeLists.txt b/lib/Conversion/FSMToSV/CMakeLists.txt new file mode 100644 index 0000000000..afbe1577cc --- /dev/null +++ b/lib/Conversion/FSMToSV/CMakeLists.txt @@ -0,0 +1,18 @@ +add_circt_conversion_library(CIRCTFSMToSV + FSMToSV.cpp + + DEPENDS + CIRCTConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + CIRCTComb + CIRCTHW + CIRCTFSM + CIRCTSeq + CIRCTSV + CIRCTSupport + MLIRTransforms +) diff --git a/lib/Conversion/FSMToSV/FSMToSV.cpp b/lib/Conversion/FSMToSV/FSMToSV.cpp new file mode 100644 index 0000000000..430e65da4e --- /dev/null +++ b/lib/Conversion/FSMToSV/FSMToSV.cpp @@ -0,0 +1,565 @@ +//===- FSMToSV.cpp - Convert FSM to HW and SV Dialect ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Conversion/FSMToSV.h" +#include "../PassDetail.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/FSM/FSMOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/SV/SVOps.h" +#include "circt/Dialect/Seq/SeqOps.h" +#include "circt/Support/BackedgeBuilder.h" +#include "llvm/ADT/TypeSwitch.h" + +#include + +using namespace mlir; +using namespace circt; +using namespace fsm; + +/// Get the port info of a FSM machine. Clock and reset port are also added. +namespace { +struct ClkRstIdxs { + size_t clockIdx; + size_t resetIdx; +}; +} // namespace +static ClkRstIdxs getMachinePortInfo(SmallVectorImpl &ports, + MachineOp machine, OpBuilder &b) { + // Get the port info of the machine inputs and outputs. + machine.getHWPortInfo(ports); + ClkRstIdxs specialPorts; + + // Add clock port. + hw::PortInfo clock; + clock.name = b.getStringAttr("clk"); + clock.direction = hw::PortDirection::INPUT; + clock.type = b.getI1Type(); + clock.argNum = machine.getNumArguments(); + ports.push_back(clock); + specialPorts.clockIdx = clock.argNum; + + // Add reset port. + hw::PortInfo reset; + reset.name = b.getStringAttr("rst"); + reset.direction = hw::PortDirection::INPUT; + reset.type = b.getI1Type(); + reset.argNum = machine.getNumArguments() + 1; + ports.push_back(reset); + specialPorts.resetIdx = reset.argNum; + + return specialPorts; +} + +namespace { + +class StateEncoding { + // An class for handling state encoding. The class is designed to + // abstract away how states are selected in case patterns, referred to as + // values, and used as selection signals for muxes. + +public: + StateEncoding(OpBuilder &b, MachineOp machine, hw::HWModuleOp hwModule); + + // Get the encoded value for a state. + Value encode(StateOp state); + // Get the state corresponding to an encoded value. + StateOp decode(Value value); + + // Returns the type which encodes the state values. + Type getStateType() { return stateType; } + + // Returns a case pattern which matches the provided state. + std::unique_ptr getCasePattern(StateOp state); + +protected: + // Creates a constant value in the module for the given encoded state + // and records the state value in the mappings. An inner symbol is + // attached to the wire to avoid it being optimized away. + // The constant can optionally be assigned behind a sv wire - doing so at this + // point ensures that constants don't end up behind "_GEN#" wires in the + // module. + void setEncoding(StateOp state, Value v, bool wire = false); + + // A mapping between a StateOp and its corresponding encoded value. + SmallDenseMap stateToValue; + + // A mapping between an encoded value and its corresponding StateOp. + SmallDenseMap valueToState; + + // A mapping between an encoded value and the source value in the IR. + SmallDenseMap valueToSrcValue; + + // The enum type for the states. + Type stateType; + + OpBuilder &b; + MachineOp machine; + hw::HWModuleOp hwModule; +}; + +StateEncoding::StateEncoding(OpBuilder &b, MachineOp machine, + hw::HWModuleOp hwModule) + : b(b), machine(machine), hwModule(hwModule) { + Location loc = machine.getLoc(); + llvm::SmallVector stateNames; + + for (auto state : machine.getBody().getOps()) + stateNames.push_back(b.getStringAttr(state.getName())); + + // Create an enum typedef for the states. + Type rawEnumType = + hw::EnumType::get(b.getContext(), b.getArrayAttr(stateNames)); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(hwModule); + auto typeScope = b.create( + loc, b.getStringAttr(hwModule.getName() + "_enum_typedecls")); + typeScope.getBodyRegion().push_back(new Block()); + + b.setInsertionPointToStart(&typeScope.getBodyRegion().front()); + auto typedeclEnumType = b.create( + loc, b.getStringAttr(hwModule.getName() + "_state_t"), + TypeAttr::get(rawEnumType), nullptr); + + stateType = hw::TypeAliasType::get( + SymbolRefAttr::get(typeScope.sym_nameAttr(), + {FlatSymbolRefAttr::get(typedeclEnumType)}), + rawEnumType); + + // And create enum values for the states + b.setInsertionPointToStart(&hwModule.getBody().front()); + for (auto state : machine.getBody().getOps()) { + auto fieldAttr = hw::EnumFieldAttr::get( + loc, b.getStringAttr(state.getName()), stateType); + auto enumConstantOp = b.create( + loc, fieldAttr.getType().getValue(), fieldAttr); + setEncoding(state, enumConstantOp, + /*wire=*/true); + } +} + +// Get the encoded value for a state. +Value StateEncoding::encode(StateOp state) { + auto it = stateToValue.find(state); + assert(it != stateToValue.end() && "state not found"); + return it->second; +} +// Get the state corresponding to an encoded value. +StateOp StateEncoding::decode(Value value) { + auto it = valueToState.find(value); + assert(it != valueToState.end() && "encoded state not found"); + return it->second; +} + +// Returns a case pattern which matches the provided state. +std::unique_ptr StateEncoding::getCasePattern(StateOp state) { + // Get the field attribute for the state - fetch it through the encoding. + auto fieldAttr = + cast(valueToSrcValue[encode(state)].getDefiningOp()) + .fieldAttr(); + return std::make_unique(fieldAttr); +} + +void StateEncoding::setEncoding(StateOp state, Value v, bool wire) { + assert(stateToValue.find(state) == stateToValue.end() && + "state already encoded"); + + Value encodedValue; + if (wire) { + auto loc = machine.getLoc(); + auto stateType = getStateType(); + auto stateEncodingWire = b.create( + loc, stateType, b.getStringAttr("to_" + state.getName()), + /*inner_sym=*/state.getNameAttr()); + b.create(loc, stateEncodingWire, v); + encodedValue = b.create(loc, stateEncodingWire); + } else + encodedValue = v; + stateToValue[state] = encodedValue; + valueToState[encodedValue] = state; + valueToSrcValue[encodedValue] = v; +} + +class MachineOpConverter { +public: + MachineOpConverter(OpBuilder &builder, MachineOp machineOp) + : machineOp(machineOp), b(builder) {} + + // Converts the machine op to a hardware module. + // 1. Creates a HWModuleOp for the machine op, with the same I/O as the FSM + + // clk/reset ports. + // 2. Creates a state register + encodings for the states visible in the + // machine. + // 3. Iterates over all states in the machine + // 3.1. Moves all `comb` logic into the body of the HW module + // 3.2. Records the SSA value(s) associated to the output ports in the state + // 3.3. iterates of the transitions of the state + // 3.3.1. Moves all `comb` logic in the transition guard/action regions to + // the body of the HW module. + // 3.3.2. Creates a case pattern for the transition guard + // 3.4. Creates a next-state value for the state based on the transition + // guards. + // 4. Assigns next-state values for the states in a case statement on the + // state reg. + // 5. Assigns the current-state outputs for the states in a case statement + // on the state reg. + LogicalResult dispatch(); + +private: + struct StateConversionResult { + // Value of the next state output signal of the converted state. + Value nextState; + // Value of the output signals of the converted state. + llvm::SmallVector outputs; + }; + + using StateConversionResults = DenseMap; + + // Converts a StateOp within this machine, and returns the value corresponding + // to the next-state output of the op. + FailureOr convertState(StateOp state); + + // Converts the outgoing transitions of a state and returns the value + // corresponding to the next-state output of the op. + // Transitions are priority encoded in the order which they appear in the + // state transition region. + FailureOr convertTransitions(StateOp currentState, + ArrayRef transitions); + + // Returns the value that must be assigned to the hw output ports of this + // machine. + llvm::SmallVector + getOutputAssignments(StateConversionResults &stateConvResults); + + // Moves operations from 'block' into module scope, failing if any op were + // deemed illegal. Returns the final op in the block if the op was a + // terminator. An optional 'exclude' filer can be provided to dynamically + // exclude some ops from being moved. + FailureOr + moveOps(Block *block, + llvm::function_ref exclude = nullptr); + + // Build a SV case-based combinational mux the values provided in + // 'stateToValue' to a retured wire. + // 'stateToValue' being a list implies that multiple muxes can be emitted at + // once, avoiding bloating the IR with a case statement for every muxed value. + // A wire is returned for each srcMap provided. + // 'nameF' can be provided to specify the name of the output wire created for + // each source map. + llvm::SmallVector + buildStateCaseMux(Location loc, Value sel, + llvm::ArrayRef> srcMaps, + llvm::function_ref nameF = {}); + + // A handle to the state encoder for this machine. + std::unique_ptr encoding; + + // A deterministic ordering of the states in this machine. + llvm::SmallVector orderedStates; + + // A mapping from a state op to its next-state value. + llvm::SmallDenseMap nextStateFromState; + + // A handle to the MachineOp being converted. + MachineOp machineOp; + + // A handle to the HW ModuleOp being created. + hw::HWModuleOp hwModuleOp; + + // A handle to the state register of the machine. + seq::CompRegOp stateReg; + + OpBuilder &b; +}; + +FailureOr +MachineOpConverter::moveOps(Block *block, + llvm::function_ref exclude) { + for (auto &op : llvm::make_early_inc_range(*block)) { + if (!isa( + op.getDialect())) + return op.emitOpError() + << "is unsupported (op from the " + << op.getDialect()->getNamespace() << " dialect)."; + + if (exclude && exclude(&op)) + continue; + + if (op.hasTrait()) + return &op; + + op.moveBefore(&hwModuleOp.front(), b.getInsertionPoint()); + } + return nullptr; +} + +llvm::SmallVector MachineOpConverter::buildStateCaseMux( + Location loc, Value sel, + llvm::ArrayRef> srcMaps, + llvm::function_ref nameF) { + sv::CaseOp caseMux; + auto caseMuxCtor = [&]() { + caseMux = b.create(loc, CaseStmtType::CaseStmt, sel, + /*numCases=*/machineOp.getNumStates(), + [&](size_t caseIdx) { + StateOp state = orderedStates[caseIdx]; + return encoding->getCasePattern(state); + }); + }; + b.create(loc, caseMuxCtor); + + llvm::SmallVector dsts; + // note: cannot use llvm::enumerate, makes the underlying iterator const. + size_t idx = 0; + for (auto srcMap : srcMaps) { + auto valueType = srcMap.begin()->second.getType(); + StringAttr name; + if (nameF) + name = nameF(idx); + + auto dst = b.create(loc, valueType, name); + OpBuilder::InsertionGuard g(b); + for (auto [caseInfo, stateOp] : + llvm::zip(caseMux.getCases(), orderedStates)) { + b.setInsertionPointToEnd(caseInfo.block); + b.create(loc, dst, srcMap[stateOp]); + } + dsts.push_back(dst); + idx++; + } + return dsts; +} + +LogicalResult MachineOpConverter::dispatch() { + if (auto varOps = machineOp.front().getOps(); !varOps.empty()) + return (*varOps.begin())->emitOpError() + << "FSM variables not yet supported for SV " + "lowering."; + + b.setInsertionPoint(machineOp); + auto loc = machineOp.getLoc(); + if (machineOp.getNumStates() < 2) + return machineOp.emitOpError() << "expected at least 2 states."; + + // 1) Get the port info of the machine and create a new HW module for it. + SmallVector ports; + auto clkRstIdxs = getMachinePortInfo(ports, machineOp, b); + hwModuleOp = b.create(loc, machineOp.sym_nameAttr(), ports); + b.setInsertionPointToStart(&hwModuleOp.front()); + + // Replace all uses of the machine arguments with the arguments of the + // new created HW module. + for (auto args : + llvm::zip(machineOp.getArguments(), hwModuleOp.front().getArguments())) { + auto machineArg = std::get<0>(args); + auto hwModuleArg = std::get<1>(args); + machineArg.replaceAllUsesWith(hwModuleArg); + } + + auto clock = hwModuleOp.front().getArgument(clkRstIdxs.clockIdx); + auto reset = hwModuleOp.front().getArgument(clkRstIdxs.resetIdx); + + // 2) Build state register. + encoding = std::make_unique(b, machineOp, hwModuleOp); + auto stateType = encoding->getStateType(); + + BackedgeBuilder bb(b, loc); + auto nextStateBackedge = bb.get(stateType); + stateReg = b.create( + loc, stateType, nextStateBackedge, clock, "state_reg", reset, + /*reset value=*/encoding->encode(machineOp.getInitialStateOp()), nullptr); + + // Move any operations at the machine-level scope, excluding state ops, which + // are handled separately. + if (failed(moveOps(&machineOp.front(), + [](Operation *op) { return isa(op); }))) { + bb.abandon(); + return failure(); + } + + // 3) Convert states and record their next-state value. + StateConversionResults stateConvResults; + for (auto state : machineOp.getBody().getOps()) { + auto stateConvRes = convertState(state); + if (failed(stateConvRes)) { + bb.abandon(); + return failure(); + } + stateConvResults[state] = stateConvRes.getValue(); + orderedStates.push_back(state); + nextStateFromState[state] = stateConvRes.getValue().nextState; + } + + // 4/5) Create next-state maps for each output and the next-state signal in a + // format suitable for creating a case mux. + llvm::SmallVector, 4> nextStateMaps; + nextStateMaps.push_back(nextStateFromState); + for (size_t portIndex = 0; portIndex < machineOp.getNumResults(); + portIndex++) { + auto &nsmap = nextStateMaps.emplace_back(); + for (auto &state : orderedStates) + nsmap[state] = stateConvResults[state].outputs[portIndex]; + } + + // Materialize the case mux. We do this in a single call to have a single + // always_comb block. + auto stateCaseMuxes = buildStateCaseMux( + machineOp.getLoc(), stateReg, nextStateMaps, [&](size_t idx) { + if (idx == 0) + return b.getStringAttr("next_state"); + + return b.getStringAttr("output_" + std::to_string(idx - 1)); + }); + + nextStateBackedge.setValue(b.create(loc, stateCaseMuxes[0])); + + llvm::SmallVector outputPortAssignments; + for (auto outputMux : llvm::makeArrayRef(stateCaseMuxes).drop_front()) + outputPortAssignments.push_back( + b.create(machineOp.getLoc(), outputMux)); + + // Delete the default created output op and replace it with the output muxes. + auto *oldOutputOp = hwModuleOp.front().getTerminator(); + b.create(loc, outputPortAssignments); + oldOutputOp->erase(); + + // Erase the original machine op. + machineOp.erase(); + + return success(); +} + +FailureOr +MachineOpConverter::convertTransitions( // NOLINT(misc-no-recursion) + StateOp currentState, ArrayRef transitions) { + Value nextState; + if (transitions.empty()) { + // Base case - transition to the current state. + nextState = encoding->encode(currentState); + } else { + // Recursive case - transition to a named state. + auto transition = cast(transitions.front()); + nextState = encoding->encode(transition.getNextState()); + if (transition.hasGuard()) { + // Not always taken; recurse and mux between the targeted next state and + // the recursion result, selecting based on the provided guard. + auto guardOpRes = moveOps(&transition.guard().front()); + if (failed(guardOpRes)) + return failure(); + + auto guard = cast(*guardOpRes).getOperand(0); + auto otherNextState = + convertTransitions(currentState, transitions.drop_front()); + if (failed(otherNextState)) + return failure(); + comb::MuxOp nextStateMux = b.create( + transition.getLoc(), guard, nextState, *otherNextState); + nextState = nextStateMux; + } + } + + assert(nextState && "next state should be defined"); + return nextState; +} + +llvm::SmallVector +MachineOpConverter::getOutputAssignments(StateConversionResults &convResults) { + + // One for each output port. + llvm::SmallVector> outputPortValues( + machineOp.getNumResults()); + for (auto &state : orderedStates) { + for (size_t portIndex = 0; portIndex < machineOp.getNumResults(); + portIndex++) + outputPortValues[portIndex][state] = + convResults[state].outputs[portIndex]; + } + + llvm::SmallVector outputPortAssignments; + + auto outputMuxes = buildStateCaseMux( + machineOp.getLoc(), stateReg, outputPortValues, [&](size_t idx) { + return b.getStringAttr("output_" + std::to_string(idx)); + }); + + for (auto outputMux : outputMuxes) + outputPortAssignments.push_back( + b.create(machineOp.getLoc(), outputMux)); + + return outputPortAssignments; +} + +FailureOr +MachineOpConverter::convertState(StateOp state) { + MachineOpConverter::StateConversionResult res; + + // 3.1) Convert the output region by moving the operations into the module + // scope and gathering the operands of the output op. + auto outputOpRes = moveOps(&state.output().front()); + if (failed(outputOpRes)) + return failure(); + + OutputOp outputOp = cast(outputOpRes.getValue()); + res.outputs = outputOp.getOperands(); // 3.2 + + auto transitions = llvm::SmallVector( + state.transitions().getOps()); + // 3.3, 3.4) Convert the transitions and record the next-state value + // derived from the transitions being selected in a priority-encoded manner. + auto nextStateRes = convertTransitions(state, transitions); + if (failed(nextStateRes)) + return failure(); + res.nextState = nextStateRes.getValue(); + return res; +} + +struct FSMToSVPass : public ConvertFSMToSVBase { + void runOnOperation() override; +}; + +void FSMToSVPass::runOnOperation() { + auto module = getOperation(); + auto b = OpBuilder(module); + SmallVector opToErase; + + // Traverse all machines and convert. + for (auto machine : llvm::make_early_inc_range(module.getOps())) { + MachineOpConverter converter(b, machine); + + if (failed(converter.dispatch())) { + signalPassFailure(); + return; + } + } + + // Traverse all machine instances and convert to hw instances. + llvm::SmallVector instances; + module.walk([&](HWInstanceOp instance) { instances.push_back(instance); }); + for (auto instance : instances) { + auto fsmHWModule = module.lookupSymbol(instance.machine()); + assert(fsmHWModule && + "FSM machine should have been converted to a hw.module"); + + b.setInsertionPoint(instance); + llvm::SmallVector operands; + llvm::transform(instance.getOperands(), std::back_inserter(operands), + [&](auto operand) { return operand; }); + auto hwInstance = b.create( + instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()), + operands, nullptr); + instance.replaceAllUsesWith(hwInstance); + instance.erase(); + } +} + +} // end anonymous namespace + +std::unique_ptr circt::createConvertFSMToSVPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index 9a4fb08471..6a1b419eff 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -90,6 +90,10 @@ namespace sv { class SVDialect; } // namespace sv +namespace fsm { +class FSMDialect; +} // namespace fsm + // Generate the classes which represent the passes #define GEN_PASS_CLASSES #include "circt/Conversion/Passes.h.inc" diff --git a/lib/Dialect/FSM/FSMOps.cpp b/lib/Dialect/FSM/FSMOps.cpp index 40361ce329..b91c1da055 100644 --- a/lib/Dialect/FSM/FSMOps.cpp +++ b/lib/Dialect/FSM/FSMOps.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace circt; @@ -59,7 +60,10 @@ void MachineOp::getHWPortInfo(SmallVectorImpl &ports) { for (unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) { hw::PortInfo port; - port.name = builder.getStringAttr("in" + std::to_string(i)); + if (argNames()) + port.name = argNames().getValue()[i].cast(); + else + port.name = builder.getStringAttr("in" + std::to_string(i)); port.direction = circt::hw::PortDirection::INPUT; port.type = machineType.getInput(i); port.argNum = i; @@ -68,7 +72,10 @@ void MachineOp::getHWPortInfo(SmallVectorImpl &ports) { for (unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) { hw::PortInfo port; - port.name = builder.getStringAttr("out" + std::to_string(i)); + if (resNames()) + port.name = resNames().getValue()[i].cast(); + else + port.name = builder.getStringAttr("out" + std::to_string(i)); port.direction = circt::hw::PortDirection::OUTPUT; port.type = machineType.getResult(i); port.argNum = i; @@ -125,6 +132,22 @@ LogicalResult MachineOp::verify() { return emitOpError("initial state '" + initialState() + "' was not defined in the machine"); + if (argNames() && argNames().getValue().size() != getArgumentTypes().size()) + return emitOpError() << "number of machine arguments (" + << getArgumentTypes().size() + << ") does " + "not match the provided number " + "of argument names (" + << argNames().getValue().size() << ")"; + + if (resNames() && resNames().getValue().size() != getResultTypes().size()) + return emitOpError() << "number of machine results (" + << getResultTypes().size() + << ") does " + "not match the provided number " + "of result names (" + << resNames().getValue().size() << ")"; + return success(); } diff --git a/lib/Dialect/SV/SVOps.cpp b/lib/Dialect/SV/SVOps.cpp index 86a39a0228..6d545ade23 100644 --- a/lib/Dialect/SV/SVOps.cpp +++ b/lib/Dialect/SV/SVOps.cpp @@ -885,7 +885,7 @@ void CaseOp::print(OpAsmPrinter &p) { LogicalResult CaseOp::verify() { if (!(hw::isHWIntegerType(getCond().getType()) || - getCond().getType().isa())) + hw::isHWEnumType(getCond().getType()))) return emitError("condition must have either integer or enum type"); // Ensure that the number of regions and number of case values match. diff --git a/test/Conversion/FSMToSV/test_basic.mlir b/test/Conversion/FSMToSV/test_basic.mlir new file mode 100644 index 0000000000..6f87c3a7ca --- /dev/null +++ b/test/Conversion/FSMToSV/test_basic.mlir @@ -0,0 +1,91 @@ +// RUN: circt-opt -split-input-file -convert-fsm-to-sv %s | FileCheck %s + +fsm.machine @FSM(%arg0: i1, %arg1: i1) -> (i8) attributes {initialState = "A"} { + %c_0 = hw.constant 0 : i8 + fsm.state @A output { + fsm.output %c_0 : i8 + } transitions { + fsm.transition @B + } + + fsm.state @B output { + fsm.output %c_0 : i8 + } transitions { + fsm.transition @A + } +} + +// ---- + +// CHECK: hw.module @top(%arg0: i1, %arg1: i1, %clk: i1, %rst: i1) -> (out: i8) { +// CHECK: %fsm_inst.out0 = hw.instance "fsm_inst" @FSM(in0: %arg0: i1, in1: %arg1: i1, clk: %clk: i1, rst: %rst: i1) -> (out0: i8) +// CHECK: hw.output %fsm_inst.out0 : i8 +// CHECK: } +hw.module @top(%arg0: i1, %arg1: i1, %clk : i1, %rst : i1) -> (out: i8) { + %out = fsm.hw_instance "fsm_inst" @FSM(%arg0, %arg1), clock %clk, reset %rst : (i1, i1) -> (i8) + hw.output %out : i8 +} + +// ----- + +// CHECK-LABEL: hw.type_scope @top_enum_typedecls { +// CHECK-NEXT: hw.typedecl @top_state_t : !hw.enum +// CHECK-NEXT: } + +// CHECK-LABEL: hw.module @top(%a0: i1, %a1: i1, %clk: i1, %rst: i1) -> (r0: i8, r1: i8) { +// CHECK-NEXT: %A = hw.enum.constant A : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: %to_A = sv.wire sym @A : !hw.inout>> +// CHECK-NEXT: sv.assign %to_A, %A : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: %0 = sv.read_inout %to_A : !hw.inout>> +// CHECK-NEXT: %B = hw.enum.constant B : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: %to_B = sv.wire sym @B : !hw.inout>> +// CHECK-NEXT: sv.assign %to_B, %B : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: %1 = sv.read_inout %to_B : !hw.inout>> +// CHECK-NEXT: %state_reg = seq.compreg %4, %clk, %rst, %0 : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: %c42_i8 = hw.constant 42 : i8 +// CHECK-NEXT: %c0_i8 = hw.constant 0 : i8 +// CHECK-NEXT: %c1_i8 = hw.constant 1 : i8 +// CHECK-NEXT: %2 = comb.and %a0, %a1 : i1 +// CHECK-NEXT: %3 = comb.mux %2, %0, %1 : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: sv.alwayscomb { +// CHECK-NEXT: sv.case %state_reg : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: case A: { +// CHECK-NEXT: sv.bpassign %next_state, %1 : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: sv.bpassign %output_0, %c0_i8 : i8 +// CHECK-NEXT: sv.bpassign %output_1, %c42_i8 : i8 +// CHECK-NEXT: } +// CHECK-NEXT: case B: { +// CHECK-NEXT: sv.bpassign %next_state, %3 : !hw.typealias<@top_enum_typedecls::@top_state_t, !hw.enum> +// CHECK-NEXT: sv.bpassign %output_0, %c1_i8 : i8 +// CHECK-NEXT: sv.bpassign %output_1, %c42_i8 : i8 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %next_state = sv.reg : !hw.inout>> +// CHECK-NEXT: %output_0 = sv.reg : !hw.inout +// CHECK-NEXT: %output_1 = sv.reg : !hw.inout +// CHECK-NEXT: %4 = sv.read_inout %next_state : !hw.inout>> +// CHECK-NEXT: %5 = sv.read_inout %output_0 : !hw.inout +// CHECK-NEXT: %6 = sv.read_inout %output_1 : !hw.inout +// CHECK-NEXT: hw.output %5, %6 : i8, i8 +// CHECK-NEXT: } + + +fsm.machine @top(%a0: i1, %arg1: i1) -> (i8, i8) attributes {initialState = "A", argNames = ["a0", "a1"], resNames = ["r0", "r1"]} { + %c_42 = hw.constant 42 : i8 + fsm.state @A output { + %c_0 = hw.constant 0 : i8 + fsm.output %c_0, %c_42 : i8, i8 + } transitions { + fsm.transition @B + } + + fsm.state @B output { + %c_1 = hw.constant 1 : i8 + fsm.output %c_1, %c_42 : i8, i8 + } transitions { + fsm.transition @A guard { + %g = comb.and %a0, %arg1 : i1 + fsm.return %g + } + } +} diff --git a/test/Conversion/FSMToSV/test_errors.mlir b/test/Conversion/FSMToSV/test_errors.mlir new file mode 100644 index 0000000000..868379a6f3 --- /dev/null +++ b/test/Conversion/FSMToSV/test_errors.mlir @@ -0,0 +1,32 @@ +// RUN: circt-opt -split-input-file -convert-fsm-to-sv -verify-diagnostics %s + + +fsm.machine @foo(%arg0: i1) -> () attributes {initialState = "A"} { + // expected-error@+1 {{'fsm.variable' op FSM variables not yet supported for SV lowering.}} + %cnt = fsm.variable "cnt" {initValue = 0 : i16} : i16 + + fsm.state @A output { + fsm.output + } transitions { + fsm.transition @A + } + +} + +// ----- + +fsm.machine @foo(%arg0: i1) -> (i1) attributes {initialState = "A"} { + // expected-error@+1 {{'arith.constant' op is unsupported (op from the arith dialect).}} + %true = arith.constant true + fsm.state @A output { + fsm.output %true : i1 + } transitions { + fsm.transition @A + } + + fsm.state @B output { + fsm.output %true : i1 + } transitions { + fsm.transition @A + } +} diff --git a/test/Dialect/FSM/basics.mlir b/test/Dialect/FSM/basics.mlir index 39681ef012..6fed85f0f7 100644 --- a/test/Dialect/FSM/basics.mlir +++ b/test/Dialect/FSM/basics.mlir @@ -35,7 +35,7 @@ // CHECK: } // CHECK: hw.module @bar(%clk: i1, %rst_n: i1) { // CHECK: %true = hw.constant true -// CHECK: %0 = fsm.hw_instance "foo_inst" @foo(%true) : (i1) -> i1, clock %clk : i1, reset %rst_n : i1 +// CHECK: %0 = fsm.hw_instance "foo_inst" @foo(%true), clock %clk, reset %rst_n : (i1) -> i1 // CHECK: hw.output // CHECK: } // CHECK: func @qux() { @@ -91,7 +91,7 @@ fsm.machine @foo(%arg0: i1) -> i1 attributes {initialState = "IDLE"} { // Hardware-style instantiation. hw.module @bar(%clk: i1, %rst_n: i1) { %in = hw.constant true - %out = fsm.hw_instance "foo_inst" @foo(%in) : (i1) -> i1, clock %clk : i1, reset %rst_n : i1 + %out = fsm.hw_instance "foo_inst" @foo(%in), clock %clk, reset %rst_n : (i1) -> i1 } // Software-style instantiation and triggering. diff --git a/test/Dialect/FSM/errors.mlir b/test/Dialect/FSM/errors.mlir index b16f0ae3d0..252c70a0a7 100644 --- a/test/Dialect/FSM/errors.mlir +++ b/test/Dialect/FSM/errors.mlir @@ -66,3 +66,18 @@ fsm.machine @foo(%arg0: i1) -> i1 attributes {initialState = "IDLE"} { } } } + +// ----- + +// expected-error @+1 {{'fsm.machine' op number of machine arguments (1) does not match the provided number of argument names (2)}} +fsm.machine @foo(%arg0: i1) -> i1 attributes {initialState = "IDLE", argNames = ["in0", "in1"]} { + fsm.state @IDLE output {} transitions {} +} + +// ----- + +// expected-error @+1 {{'fsm.machine' op number of machine results (1) does not match the provided number of result names (2)}} +fsm.machine @foo(%arg0: i1) -> i1 attributes {initialState = "IDLE", resNames = ["out0", "out1"]} { + fsm.state @IDLE output {} transitions {} +} + diff --git a/tools/circt-opt/CMakeLists.txt b/tools/circt-opt/CMakeLists.txt index 5ebe955a1c..3cf7a0a185 100644 --- a/tools/circt-opt/CMakeLists.txt +++ b/tools/circt-opt/CMakeLists.txt @@ -20,6 +20,7 @@ target_link_libraries(circt-opt CIRCTFIRRTLTransforms CIRCTFSM CIRCTFSMTransforms + CIRCTFSMToSV CIRCTHandshake CIRCTHandshakeToFIRRTL CIRCTHandshakeToHW