[gg] Implement multi-clock compatible wiring

This commit is contained in:
David Biancolin 2020-03-11 03:58:25 +00:00
parent ab9b0c22e5
commit 320f6360f1
5 changed files with 257 additions and 18 deletions

View File

@ -14,6 +14,7 @@ import firrtl.ir._
import logger._
import firrtl.Mappers._
import firrtl.transforms.{DedupModules, DeadCodeElimination}
import firrtl.passes.wiring.{WiringTransform}
import Utils._
import java.io.{File, FileWriter}
@ -55,6 +56,13 @@ private[midas] class MidasTransforms(implicit p: Parameters) extends Transform {
new AssertPass(dir),
new PrintSynthesis(dir),
new ResolveAndCheck,
TriggerWiring,
new EmitFirrtl("post-trigger-wiring.fir"),
new fame.EmitFAMEAnnotations("post-trigger-wiring.json"),
new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new WiringTransform,
new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new MiddleFirrtlToLowFirrtl,
new EmitFirrtl("post-debug-synthesis.fir"),

View File

@ -0,0 +1,193 @@
//See LICENSE for license details.
package midas.passes
import midas.targetutils.{TriggerSourceAnnotation, TriggerSinkAnnotation}
import midas.passes.fame.{FAMEChannelConnectionAnnotation, TargetClockChannel, Neq, RegZeroPreset}
import freechips.rocketchip.util.DensePrefixSum
import firrtl._
import firrtl.annotations._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.passes.wiring.{SinkAnnotation, SourceAnnotation, WiringTransform}
import firrtl.Utils.{zero, BoolType}
import scala.collection.mutable
private[passes] object TriggerWiring extends firrtl.Transform {
def inputForm = LowForm
def outputForm = HighForm
override def name = "[Golden Gate] Trigger Wiring"
val topWiringPrefix = "simulationTrigger_"
val localCType = UIntType(IntWidth(16))
val globalCType = UIntType(IntWidth(32))
val sinkWiringKey = "trigger_sink"
def onModuleSink(sinkAnnoModuleMap: Map[String, Seq[TriggerSinkAnnotation]],
addedAnnos: mutable.ArrayBuffer[Annotation])
(m: DefModule): DefModule = m match {
case m: Module if sinkAnnoModuleMap.isDefinedAt(m.name) =>
val sinkNameMap = sinkAnnoModuleMap(m.name).map(anno => anno.target.ref -> anno).toMap
val ns = Namespace(m)
m.map(onStmtSink(sinkNameMap, addedAnnos, ns))
case o => o
}
def onStmtSink(sinkAnnos: Map[String, TriggerSinkAnnotation],
addedAnnos: mutable.ArrayBuffer[Annotation],
ns: Namespace)
(s: Statement): Statement = s.map(onStmtSink(sinkAnnos, addedAnnos, ns)) match {
case node@DefNode(_,name,_) if sinkAnnos.isDefinedAt(name) =>
val sinkAnno = sinkAnnos(name)
val mT = sinkAnno.enclosingModuleTarget
val triggerSyncName = ns.newName("trigger_sync")
val triggerSync = RegZeroPreset(NoInfo, triggerSyncName, BoolType, WRef(sinkAnno.clock.ref))
addedAnnos += SinkAnnotation(mT.ref(triggerSyncName).toNamed, sinkWiringKey)
Block(triggerSync, node.copy(value = WRef(triggerSync)))
// Implement the missing cases?
case r@DefRegister(_,name,_,_,_,_) if sinkAnnos.isDefinedAt(name) => ???
case s => s
}
def execute(state: CircuitState): CircuitState = {
val topModName = state.circuit.main
val topMod = state.circuit.modules.find(_.name == topModName).get
val prexistingPorts = topMod.ports
// 1) Collect Trigger Annotations, and generate BridgeTopWiring annotations
val srcCreditAnnos = new mutable.ArrayBuffer[TriggerSourceAnnotation]()
val srcDebitAnnos = new mutable.ArrayBuffer[TriggerSourceAnnotation]()
val sinkAnnos = new mutable.ArrayBuffer[TriggerSinkAnnotation]()
state.annotations.collect({
case a: TriggerSourceAnnotation if a.sourceType => srcCreditAnnos += a
case a: TriggerSourceAnnotation => srcDebitAnnos += a
case a: TriggerSinkAnnotation => sinkAnnos += a
})
require(!(srcCreditAnnos.isEmpty && srcDebitAnnos.nonEmpty), "Provided trigger debit sources but no credit sources")
// It may make sense to relax this in the future
// This would enable trigger without posibility of disabling it in the future
require(!(srcDebitAnnos.isEmpty && srcCreditAnnos.nonEmpty), "Provided trigger credit sources but no debit sources")
val updatedState = if (srcCreditAnnos.isEmpty && srcDebitAnnos.isEmpty || sinkAnnos.isEmpty) {
state
} else {
val bridgeTopWiringAnnos = (srcCreditAnnos ++ srcDebitAnnos).map(anno =>
BridgeTopWiringAnnotation(anno.target, anno.clock))
// 2) Use bridge topWiring to generate inter-module connectivity -- but drop the port list
val wiredState = (new BridgeTopWiring(topWiringPrefix)).execute(state.copy(
annotations = state.annotations ++ bridgeTopWiringAnnos))
val wiredTopModule = wiredState.circuit.modules.collectFirst({
case m@Module(_,name,_,_) if name == topModName => m
}).get
val otherModules = wiredState.circuit.modules.filter(_.name != topModName)
val addedPorts = wiredTopModule.ports.filterNot(p => prexistingPorts.contains(p))
// Step 3: Group top-wired outputs by their associated clock
val outputAnnos = wiredState.annotations.collect({ case a: BridgeTopWiringOutputAnnotation => a })
val groupedTriggers = outputAnnos.groupBy(_.clockPort)
// Step 4: Convert port assignments to wire assignments
val portName2WireMap = addedPorts.map(p => p.name -> DefWire(NoInfo, p.name, p.tpe)).toMap
def updateAssignments(stmt: Statement): Statement = stmt.map(updateAssignments) match {
case c@Connect(_, WRef(name,_,_,_), _) if portName2WireMap.isDefinedAt(name) =>
val defWire = portName2WireMap(name)
Block(defWire, c.copy(loc = WRef(defWire)))
case c@Connect(_, _, WRef(name,_,_,_)) if portName2WireMap.isDefinedAt(name) =>
val defWire = portName2WireMap(name)
Block(defWire, c.copy(expr = WRef(defWire)))
case o => o
}
val portRemovedBody = wiredTopModule.body.map(updateAssignments)
// 5) Per-clock-domain: generate clock-domain popcount
val ns = Namespace(wiredTopModule)
val addedStmts = new mutable.ArrayBuffer[Statement]()
def popCount(bools: Seq[WRef]): WRef = DensePrefixSum(bools)({ case (a, b) =>
val name = ns.newTemp
val node = DefNode(NoInfo, name, DoPrim(PrimOps.Add, Seq(a, b), Seq.empty, UnknownType))
addedStmts += node
WRef(node)
}).last
def counter(name: String, tpe: UIntType, clock: WRef, incr: WRef): (DefRegister, DefNode) = {
val countName = ns.newName(name)
val count = RegZeroPreset(NoInfo, countName, tpe, clock)
val nextName = ns.newName(countName + "_next")
val next = DefNode(NoInfo, nextName, DoPrim(PrimOps.Add, Seq(WRef(count), incr), Seq.empty, tpe))
val countUpdate = Connect(NoInfo, WRef(count), WRef(next))
addedStmts ++= Seq(count, next, countUpdate)
(count, next)
}
def doAccounting(counterType: UIntType, clock: WRef)(name: String, bools: Seq[WRef]): WRef =
WRef(counter(name, counterType, clock, popCount(bools))._2)
val (localCredits, localDebits) = (for ((clockRT, oAnnos) <- groupedTriggers) yield {
val credits = oAnnos.collect {
case a if srcCreditAnnos.exists(_.target == a.pathlessSource) => WRef(portName2WireMap(a.topSink.ref))
}
val debits = oAnnos.collect {
case a if srcDebitAnnos.exists(_.target == a.pathlessSource) => WRef(portName2WireMap(a.topSink.ref))
}
def doLocalAccounting = doAccounting(localCType, WRef(clockRT.ref)) _
val domainName = clockRT.ref
(doLocalAccounting(s"${domainName}_credits", credits), doLocalAccounting(s"${domainName}_debits", debits))
}).unzip
// 6) Synchronize and aggregate counts in reference domain
val refClockRT = wiredState.annotations.collectFirst({
case FAMEChannelConnectionAnnotation(_,TargetClockChannel,_,_,Some(clock :: _)) => clock
}).get
def syncAndDiff(next: WRef): WRef = {
val name = next.name
val syncNameS1 = ns.newName(s"${name}_count_sync_s1")
val syncS1 = RegZeroPreset(NoInfo, syncNameS1, localCType, WRef(refClockRT.ref))
val syncNameS2 = ns.newName(s"${name}_count_sync_s2")
val syncS2 = RegZeroPreset(NoInfo, syncNameS2, localCType, WRef(refClockRT.ref))
val diffName = ns.newName(s"${name}_diff")
val diffNode = DefNode(NoInfo, diffName, DoPrim(PrimOps.Sub, Seq(WRef(syncS1), WRef(syncS2)), Seq.empty, localCType))
addedStmts ++= Seq(
syncS1, syncS2, diffNode,
Connect(NoInfo, WRef(syncS1), next),
Connect(NoInfo, WRef(syncS2), WRef(syncS1))
)
WRef(diffNode)
}
val creditUpdates = localCredits.map(syncAndDiff).toSeq
val debitUpdates = localDebits.map(syncAndDiff).toSeq
def doGlobalAccounting = doAccounting(globalCType, WRef(refClockRT.ref)) _
val totalCredit = doGlobalAccounting("totalCredits", creditUpdates)
val totalDebit = doGlobalAccounting("totalDebits", debitUpdates)
val triggerName = ns.newName("trigger_source")
val triggerSource = DefNode(NoInfo, triggerName, Neq(totalCredit, totalDebit))
val triggerSourceRT = ModuleTarget(topModName, topModName).ref(triggerName)
addedStmts += triggerSource
val topModWithTrigger = wiredTopModule.copy(ports = prexistingPorts, body = Block(portRemovedBody, addedStmts:_*))
val updatedCircuit = wiredState.circuit.copy(modules = topModWithTrigger +: otherModules)
// Step 7) Wire generated trigger to all sinks
val sinkModuleMap = sinkAnnos.groupBy(_.target.module)
val wiringAnnos = new mutable.ArrayBuffer[Annotation]
wiringAnnos += SourceAnnotation(triggerSourceRT.toNamed, sinkWiringKey)
val preSinkWiringCircuit = updatedCircuit.map(onModuleSink(sinkModuleMap, wiringAnnos))
CircuitState(preSinkWiringCircuit, HighForm, wiredState.annotations ++ wiringAnnos)
}
val cleanedAnnos = updatedState.annotations.flatMap({
case a: TriggerSourceAnnotation => None
case a: TriggerSinkAnnotation => None
case a: BridgeTopWiringOutputAnnotation => None
case o => Some(o)
})
updatedState.copy(annotations = cleanedAnnos)
}
}

View File

@ -242,6 +242,7 @@ case class TriggerSinkAnnotation(
val renamedClock = renamer.exactRename(clock)
Seq(this.copy(target = renamedTarget, clock = renamedClock))
}
def enclosingModuleTarget(): ModuleTarget = ModuleTarget(target.circuit, target.module)
}
object TriggerSource {

View File

@ -6,27 +6,44 @@
class TriggerWiringModule_t: virtual simif_t
{
public:
synthesized_assertions_t * assert_endpoint;
std::vector<synthesized_assertions_t *> assert_endpoints;
TriggerWiringModule_t(int argc, char** argv) {
ASSERTBRIDGEMODULE_0_substruct_create;
assert_endpoint = new synthesized_assertions_t(this,
ASSERTBRIDGEMODULE_1_substruct_create;
assert_endpoints.push_back(new synthesized_assertions_t(this,
ASSERTBRIDGEMODULE_0_substruct,
ASSERTBRIDGEMODULE_0_assert_count,
ASSERTBRIDGEMODULE_0_assert_messages);
ASSERTBRIDGEMODULE_0_assert_messages));
assert_endpoints.push_back(new synthesized_assertions_t(this,
ASSERTBRIDGEMODULE_1_substruct,
ASSERTBRIDGEMODULE_1_assert_count,
ASSERTBRIDGEMODULE_1_assert_messages));
};
bool simulation_complete() {
bool is_complete = false;
for (auto &e: assert_endpoints) {
is_complete |= e->terminate();
}
return is_complete;
}
int exit_code(){
for (auto &e: assert_endpoints) {
if (e->exit_code())
return e->exit_code();
}
return 0;
}
void run() {
int assertions_thrown = 0;
poke(reset, 1);
step(1);
poke(reset, 0);
step(10000, false);
while (!done()) {
assert_endpoint->tick();
if (assert_endpoint->terminate()) {
assert_endpoint->resume();
assertions_thrown++;
while (!done() && !simulation_complete()) {
for (auto ep: assert_endpoints) {
ep->tick();
}
}
expect(assertions_thrown == 0, "No assertions should be thrown");
expect(!exit_code(), "No assertions should be thrown");
}
};

View File

@ -4,9 +4,10 @@ package firesim.midasexamples
import midas.widgets.{RationalClockBridge, PeekPokeBridge}
import midas.targetutils.{TriggerSource, TriggerSink}
import freechips.rocketchip.util.DensePrefixSum
import freechips.rocketchip.util.{DensePrefixSum, ResetCatchAndSync}
import chisel3._
import chisel3.util._
import chisel3.experimental.chiselName
import scala.collection.mutable
@ -29,10 +30,10 @@ trait SourceCredit { self: MultiIOModule =>
trait SourceDebit { self: MultiIOModule =>
val referenceDebit = IO(Output(Bool()))
private val lfsr = LFSR16()
referenceDebit := lfsr(0) ^ lfsr(1)
referenceDebit := ShiftRegister(lfsr(0), 5)
val debit = Wire(Bool())
debit := referenceDebit
TriggerSource.credit(debit)
TriggerSource.debit(debit)
}
class TriggerSourceModule extends MultiIOModule with SourceCredit with SourceDebit
@ -56,6 +57,7 @@ class ReferenceSourceCounters(numCredits: Int, numDebits: Int) extends MultiIOMo
totalCredit := doAccounting(inputCredits)
totalDebit := doAccounting(inputDebits)
@chiselName
def synchAndDiff(count: UInt): UInt = {
val sync = RegNext(count)
val syncLast = RegNext(sync)
@ -79,10 +81,11 @@ object ReferenceSourceCounters {
// as seen by all nodes with a trigger sink, fail to match their references.
class TriggerWiringModule extends RawModule {
val clockBridge = Module(new RationalClockBridge(1000, (1,2)))
val refClock :: div2Clock = clockBridge.io.clocks.toList
val refClock :: div2Clock :: _ = clockBridge.io.clocks.toList
val refSourceCounts = new mutable.ArrayBuffer[ReferenceSourceCounters]()
val refSinks = new mutable.ArrayBuffer[Bool]()
val reset = WireInit(false.B)
val resetHalfRate = ResetCatchAndSync(div2Clock, reset.toBool)
withClockAndReset(refClock, reset) {
val peekPokeBridge = PeekPokeBridge(refClock, reset)
val src = Module(new TriggerSourceModule)
@ -97,13 +100,30 @@ class TriggerWiringModule extends RawModule {
}
}
// Reference Trigger Enable
withClockAndReset(refClock, reset) {
withClockAndReset(div2Clock, resetHalfRate) {
val src = Module(new TriggerSourceModule)
val sink = Module(new TriggerSinkModule)
// Reference Hardware
refSourceCounts += ReferenceSourceCounters(Seq(src.referenceCredit), Seq(src.referenceDebit))
refSinks += {
val syncReg = Reg(Bool())
sink.reference := syncReg
syncReg
}
}
@chiselName
class ReferenceImpl {
val totalCredit = Reg(UInt(32.W))
val totalDebit = Reg(UInt(32.W))
totalCredit := totalCredit + DensePrefixSum(refSourceCounts.map(_.syncAndDiffCredits))(_ + _).last
totalDebit := totalDebit + DensePrefixSum(refSourceCounts.map(_.syncAndDiffDebits))(_ + _).last
val triggerEnable = totalCredit =/= totalDebit
val creditNext = totalCredit + DensePrefixSum(refSourceCounts.map(_.syncAndDiffCredits))(_ + _).last
val debitNext = totalDebit + DensePrefixSum(refSourceCounts.map(_.syncAndDiffDebits))(_ + _).last
totalCredit := creditNext
totalDebit := debitNext
val triggerEnable = creditNext =/= debitNext
refSinks foreach { _ := triggerEnable }
}
// Reference Trigger Enable
withClock(refClock) { new ReferenceImpl }
}