Add logic to deal with clock channel connections

This commit is contained in:
Albert Magyar 2019-11-13 14:27:21 -08:00
parent d95c3d2e83
commit dbbe9d657c
5 changed files with 135 additions and 67 deletions

View File

@ -51,9 +51,12 @@ private[midas] class MidasTransforms(
new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new MiddleFirrtlToLowFirrtl,
new EmitFirrtl("pre-bridge-extraction.fir"),
new fame.EmitFAMEAnnotations("pre-bridge-extraction.json"),
new BridgeExtraction,
new ResolveAndCheck,
new fame.EmitFAMEAnnotations("post-bridge-extraction.json"),
new EmitFirrtl("post-bridge-extraction.fir"),
new ResolveAndCheck,
new MiddleFirrtlToLowFirrtl,
new fame.WrapTop,
new ResolveAndCheck,
@ -66,9 +69,12 @@ private[midas] class MidasTransforms(
new HighFirrtlToMiddleFirrtl,
new MiddleFirrtlToLowFirrtl,
new fame.FAMEDefaults,
new EmitFirrtl("pre-channel-excision.fir"),
new fame.ChannelExcision,
new fame.InferModelPorts,
new EmitFirrtl("post-channel-excision.fir"),
new fame.EmitFAMEAnnotations("post-channel-excision.json"),
new fame.InferModelPorts,
new fame.EmitFAMEAnnotations("post-infer-model-ports.json"),
new fame.FAMETransform,
DefineAbstractClockGate,
new EmitFirrtl("post-fame-transform.fir"),

View File

@ -3,6 +3,13 @@ package midas.passes.fame
import firrtl._
import annotations._
/**
* A mixed-in ancestor trait for all FAME annotations, useful for type-casing.
*/
trait FAMEAnnotation {
this: Annotation =>
}
/**
* An annotation that describes the ports that constitute one channel
* from the perspective of a particular module that will be replaced
@ -27,7 +34,7 @@ import annotations._
case class FAMEChannelPortsAnnotation(
localName: String,
clockPort: Option[ReferenceTarget],
ports: Seq[ReferenceTarget]) extends Annotation {
ports: Seq[ReferenceTarget]) extends Annotation with FAMEAnnotation {
def update(renames: RenameMap): Seq[Annotation] = {
val renamer = RTRenamer.exact(renames)
Seq(FAMEChannelPortsAnnotation(localName, clockPort.map(renamer), ports.map(renamer)))
@ -53,7 +60,7 @@ case class FAMEChannelConnectionAnnotation(
channelInfo: FAMEChannelInfo,
clock: Option[ReferenceTarget],
sources: Option[Seq[ReferenceTarget]],
sinks: Option[Seq[ReferenceTarget]]) extends Annotation with HasSerializationHints {
sinks: Option[Seq[ReferenceTarget]]) extends Annotation with FAMEAnnotation with HasSerializationHints {
def update(renames: RenameMap): Seq[Annotation] = {
val renamer = RTRenamer.exact(renames)
Seq(FAMEChannelConnectionAnnotation(globalName, channelInfo.update(renames), clock.map(renamer), sources.map(_.map(renamer)), sinks.map(_.map(renamer))))
@ -62,14 +69,15 @@ case class FAMEChannelConnectionAnnotation(
def getBridgeModule(): String = sources.getOrElse(sinks.get).head.module
// TODO: Maybe clocks should become associated with module port here?
// TODO (David): Maybe clocks should become associated with module port here?
// If so, the pass calling this would handle the clocks.
// Otherwise, give this a different name to make it clear that it's not moving everything
// POSSIBLY FIXED (Albert): I included the clock in the renamed targets
def moveFromBridge(portName: String): FAMEChannelConnectionAnnotation = {
def updateRT(rT: ReferenceTarget): ReferenceTarget = ModuleTarget(rT.circuit, rT.circuit).ref(portName).field(rT.ref)
require(sources == None || sinks == None, "Bridge-connected channels cannot loopback")
val rTs = sources.getOrElse(sinks.get) ++ (channelInfo match {
val rTs = sources.getOrElse(sinks.get) ++ clock ++ (channelInfo match {
case i: DecoupledForwardChannel => Seq(i.readySink.getOrElse(i.readySource.get))
case other => Seq()
})
@ -146,7 +154,7 @@ case object DecoupledReverseChannel extends FAMEChannelInfo
case object TargetClockChannel extends FAMEChannelInfo
/**
* Indicates that a channel connection is the reverse (ready) half of
* Indicates that a channel connection is the forward (valid) half of
* a decoupled target connection.
*
* @param readySink sink port component of the corresponding reverse channel
@ -186,7 +194,8 @@ object DecoupledForwardChannel {
/**
* Indicates that a particular instance is a FAME Model
*/
case class FAMEModelAnnotation(target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] {
case class FAMEModelAnnotation(
target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] with FAMEAnnotation {
def targets = Seq(target)
def duplicate(n: InstanceTarget) = this.copy(n)
}
@ -217,7 +226,9 @@ case object FAME1Transform extends FAMETransformType
* this is a ModuleTarget, all instances at the top level will be
* transformed identically.
*/
case class FAMETransformAnnotation(transformType: FAMETransformType, target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] {
case class FAMETransformAnnotation(
transformType: FAMETransformType,
target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] with FAMEAnnotation {
def targets = Seq(target)
def duplicate(n: ModuleTarget) = this.copy(transformType, n)
}
@ -232,12 +243,13 @@ case class FAMETransformAnnotation(transformType: FAMETransformType, target: Mod
* be a *local* instance target, as all instances of the parent
* module will be transformed identically.
*/
case class PromoteSubmoduleAnnotation(target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] {
case class PromoteSubmoduleAnnotation(
target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] with FAMEAnnotation {
def targets = Seq(target)
def duplicate(n: InstanceTarget) = this.copy(n)
}
abstract class FAMEGlobalSignal extends SingleTargetAnnotation[ReferenceTarget] {
abstract class FAMEGlobalSignal extends SingleTargetAnnotation[ReferenceTarget] with FAMEAnnotation {
val target: ReferenceTarget
def targets = Seq(target)
def duplicate(n: ReferenceTarget): FAMEGlobalSignal
@ -251,7 +263,7 @@ case class FAMEHostReset(target: ReferenceTarget) extends FAMEGlobalSignal {
def duplicate(t: ReferenceTarget): FAMEHostReset = this.copy(t)
}
abstract class MemPortAnnotation extends Annotation {
abstract class MemPortAnnotation extends Annotation with FAMEAnnotation {
val en: ReferenceTarget
val addr: ReferenceTarget
}
@ -320,3 +332,24 @@ case class ModelReadWritePort(
}
override def getTargets: Seq[ReferenceTarget] = Seq(wmode, rdata, wdata, wmask, addr, en)
}
/**
* A pass that dumps all FAME annotations to a file for debugging.
*/
class EmitFAMEAnnotations(fileName: String) extends firrtl.Transform {
import firrtl.options.TargetDirAnnotation
def inputForm = UnknownForm
def outputForm = UnknownForm
override def name = s"[MIDAS] Debugging FAME Annotation Emission Pass: $fileName"
def execute(state: CircuitState) = {
val targetDir = state.annotations.collectFirst { case TargetDirAnnotation(dir) => dir }
val dirName = targetDir.getOrElse(".")
val outputFile = new java.io.PrintWriter(s"${dirName}/${fileName}")
val fameAnnos = state.annotations.collect { case fa: FAMEAnnotation => fa }
outputFile.write(JsonProtocol.serialize(fameAnnos))
outputFile.close()
state
}
}

View File

@ -25,17 +25,13 @@ import midas.passes._
trait FAME1Channel {
def name: String
def direction: Direction
def clockDomainEnable: Port
def ports: Seq[Port]
def firedReg: DefRegister
def tpe: Type = FAMEChannelAnalysis.getHostDecoupledChannelType(name, ports)
def portName: String
def asPort: Port = Port(NoInfo, portName, direction, tpe)
def isReady: Expression = WSubField(WRef(asPort), "ready", BoolType)
def isValid: Expression = WSubField(WRef(asPort), "valid", BoolType)
def isFiring: Expression = And(isReady, isValid)
def isFired = WRef(firedReg)
def isFiredOrFiring = Or(isFired, isFiring)
def replacePortRef(wr: WRef): WSubField = {
if (ports.size > 1) {
WSubField(WSubField(WRef(asPort), "bits"), FAMEChannelAnalysis.removeCommonPrefix(wr.name, name)._1)
@ -43,12 +39,28 @@ trait FAME1Channel {
WSubField(WRef(asPort), "bits")
}
}
}
case class FAME1ClockChannel(name: String, ports: Seq[Port]) extends FAME1Channel {
val direction = Input
val portName = s"${name}_sink" // clock channel port on model sinks clocks
}
trait FAME1DataChannel extends FAME1Channel {
def clockDomainEnable: Port
def firedReg: DefRegister
def isFired = WRef(firedReg)
def isFiredOrFiring = Or(isFired, isFiring)
def updateFiredReg(finishing: WRef): Statement = {
Connect(NoInfo, isFired, Mux(finishing, Negate(WRef(clockDomainEnable)), isFiredOrFiring, BoolType))
}
}
case class FAME1InputChannel(val name: String, val clockDomainEnable: Port, val ports: Seq[Port], val firedReg: DefRegister) extends FAME1Channel {
case class FAME1InputChannel(
name: String,
clockDomainEnable: Port,
ports: Seq[Port],
firedReg: DefRegister) extends FAME1DataChannel {
val direction = Input
val portName = s"${name}_sink"
def setReady(finishing: WRef): Statement = {
@ -56,7 +68,11 @@ case class FAME1InputChannel(val name: String, val clockDomainEnable: Port, val
}
}
case class FAME1OutputChannel(val name: String, val clockDomainEnable: Port, val ports: Seq[Port], val firedReg: DefRegister) extends FAME1Channel {
case class FAME1OutputChannel(
name: String,
clockDomainEnable: Port,
ports: Seq[Port],
firedReg: DefRegister) extends FAME1DataChannel {
val direction = Output
val portName = s"${name}_source"
def setValid(finishing: WRef, ccDeps: Iterable[FAME1InputChannel]): Statement = {
@ -64,12 +80,6 @@ case class FAME1OutputChannel(val name: String, val clockDomainEnable: Port, val
}
}
object ChannelCCDependencyGraph {
def apply(m: Module): LinkedHashMap[FAME1OutputChannel, LinkedHashSet[FAME1InputChannel]] = {
new LinkedHashMap[FAME1OutputChannel, LinkedHashSet[FAME1InputChannel]]
}
}
// Multi-clock timestep:
// When finishing is high, dequeue token from clock channel
// - Use to initialize isFired for all channels (with negation)
@ -78,6 +88,7 @@ object FAMEModuleTransformer {
def apply(m: Module, analysis: FAMEChannelAnalysis): Module = {
// Step 0: Bookkeeping for port structure conventions
implicit val ns = Namespace(m)
val mTarget = ModuleTarget(analysis.circuit.main, m.name)
val clocks: Seq[Port] = m.ports.filter(_.tpe == ClockType)
val portsByName = m.ports.map(p => p.name -> p).toMap
assert(clocks.length >= 1)
@ -93,11 +104,19 @@ object FAMEModuleTransformer {
DefRegister(NoInfo, ns.newName(suggestName), BoolType, WRef(hostClock), WRef(hostReset), resetVal)
}
// Multi-clock management step 2: Convert all clock ports to enables of same name
val targetClockEns: Seq[Port] = clocks.map(_.copy(tpe = BoolType))
// Multi-clock management step 2: Build clock flags and clock channel
def isClockChannel(info: (String, (Option[Port], Seq[Port]))) = info match {
case (_, (clk, ports)) => clk.isEmpty && ports.forall(_.tpe == ClockType)
}
// Multi-clock management step 3: Generate clock buffers for all target clocks
val targetClockBufs: Seq[SignalInfo] = targetClockEns.map { en =>
val clockChannel = analysis.modelInputChannelPortMap(mTarget).find(isClockChannel) match {
case Some((name, (None, ports))) => FAME1ClockChannel(name, ports.map(_.copy(tpe = BoolType)))
case Some(_) => ??? // Clock channel cannot have
case None => ??? // Clock channel is mandatory for now
}
// Multi-clock management step 4: Generate clock buffers for all target clocks
val targetClockBufs: Seq[SignalInfo] = clockChannel.ports.map { en =>
val enableReg = hostFlagReg(s"${en.name}_enabled", resetVal = UIntLiteral(1))
val buf = WDefInstance(DefineAbstractClockGate.blackbox.name, ns.newName(s"${en.name}_buffer"))
val connects = Block(Seq(
@ -107,38 +126,40 @@ object FAMEModuleTransformer {
SignalInfo(Block(Seq(enableReg, buf)), connects, WSubField(WRef(buf), "O", ClockType, MALE))
}
// Multi-clock management step 4: Generate target clock substitution map
// Multi-clock management step 5: Generate target clock substitution map
def asWE(p: Port) = WrappedExpression.we(WRef(p))
val replaceClocksMap = (targetClockEns.map(p => asWE(p)) zip targetClockBufs.map(_.ref)).toMap
val replaceClocksMap = (clockChannel.ports.map(p => asWE(p)) zip targetClockBufs.map(_.ref)).toMap
// LI-BDN transformation step 1: Build channels
// TODO: get rid of the analysis calls; we just need connectivity & annotations
// Would be shorter, could merge common code for in/out channels
val portDeps = analysis.connectivity(m.name)
val mTarget = ModuleTarget(analysis.circuit.main, m.name)
val inChannels = (analysis.modelInputChannelPortMap(mTarget)).map({
case(cName, (Some(clock), ports)) =>
val sourceClocks = portDeps.getEdges(clock.name)
assert(sourceClocks.size == 1) // must be driven by one clock input
val clockDomainEnable = portsByName(sourceClocks.head).copy(tpe = BoolType)
val firedReg = hostFlagReg(suggestName = ns.newName(s"${cName}_fired"))
new FAME1InputChannel(cName, clockDomainEnable, ports, firedReg)
case (_, (None, _)) => ??? // clocks are currently mandatory in channels
}).toSeq
val inChannelMap = new LinkedHashMap[String, FAME1InputChannel] ++
(inChannels.flatMap(c => c.ports.map(p => (p.name, c))))
val outChannels = analysis.modelOutputChannelPortMap(mTarget).map({
case(cName, (Some(clock), ports)) =>
def genMetadata(info: (String, (Option[Port], Seq[Port]))) = info match {
case (cName, (Some(clock), ports)) =>
val sourceClocks = portDeps.getEdges(clock.name)
assert(sourceClocks.size == 1) // must be driven by one clock input
val clockDomainEnable = portsByName(sourceClocks.head).copy(tpe = BoolType)
val clkFlag = portsByName(sourceClocks.head).copy(tpe = BoolType)
val firedReg = hostFlagReg(suggestName = ns.newName(s"${cName}_fired"))
new FAME1OutputChannel(cName, clockDomainEnable, ports, firedReg)
case (_, (None, _)) => ??? // clocks are currently mandatory in channels
}).toSeq
val outChannelMap = new LinkedHashMap[String, FAME1OutputChannel] ++
(outChannels.flatMap(c => c.ports.map(p => (p.name, c))))
(cName, clkFlag, ports, firedReg)
case (cName, (None, ports)) =>
println(s"Channel ${cName} has no clock")
ports.foreach { p => println(s" ${p}") }
??? // clock is mandatory for now
}
// LinkedHashMap.from is 2.13-only :(
def stableMap[K, V](contents: Iterable[(K, V)]) = new LinkedHashMap[K, V] ++= contents
// Have to filter out the clock channel from the input channels
val inChannelInfo = analysis.modelInputChannelPortMap(mTarget).filterNot(isClockChannel(_)).toSeq
val inChannelMetadata = inChannelInfo.map(genMetadata(_))
val inChannels = inChannelMetadata.map((FAME1InputChannel.apply _).tupled)
val inChannelMap = stableMap(inChannels.flatMap(c => c.ports.map(p => p.name -> c)))
val outChannelInfo = analysis.modelOutputChannelPortMap(mTarget).toSeq
val outChannelMetadata = outChannelInfo.map(genMetadata(_))
val outChannels = outChannelMetadata.map((FAME1OutputChannel.apply _).tupled)
val outChannelMap = stableMap(outChannels.flatMap(c => c.ports.map(p => p.name -> c)))
// LI-BDN transformation step 2: find combinational dependencies among channels
val ccDeps = new LinkedHashMap[FAME1OutputChannel, LinkedHashSet[FAME1InputChannel]]
@ -148,7 +169,7 @@ object FAMEModuleTransformer {
})
// LI-BDN transformation step 3: transform ports (includes new clock ports)
val transformedPorts = Seq(hostClock, hostReset) ++ targetClockEns ++ (inChannels ++ outChannels).map(_.asPort)
val transformedPorts = hostClock +: hostReset +: (clockChannel +: inChannels ++: outChannels).map(_.asPort)
// LI-BDN transformation step 4: replace port and clock references and gate state updates
def onExpr(expr: Expression): Expression = expr match {
@ -170,15 +191,13 @@ object FAMEModuleTransformer {
// LI-BDN transformation step 5: add firing rules for output channels, trigger end of cycle
// This is modified for multi-clock, as each channel fires only when associated clock is enabled
val allFiredOrFiring = And.reduce(outChannels.map(_.isFiredOrFiring) ++ inChannels.map(_.isValid))
val clockChannelReady = ???
val clockChannelValid = ???
val channelStateRules = (inChannels ++ outChannels).map(c => c.updateFiredReg(WRef(finishing)))
val inputRules = inChannels.map(i => i.setReady(WRef(finishing)))
val outputRules = outChannels.map(o => o.setValid(WRef(finishing), ccDeps(o)))
val topRules = Seq(
Connect(NoInfo, clockChannelReady, allFiredOrFiring),
Connect(NoInfo, WRef(finishing), And(allFiredOrFiring, clockChannelValid)))
Connect(NoInfo, clockChannel.isReady, allFiredOrFiring),
Connect(NoInfo, WRef(finishing), And(allFiredOrFiring, clockChannel.isValid)))
// Statements have to be conservatively ordered to satisfy declaration order
val decls = finishing +: targetClockBufs.map(_.decl) ++: (inChannels ++ outChannels).map(_.firedReg)

View File

@ -179,6 +179,15 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F
val hostReset = state.annotations.collect({ case FAMEHostReset(rt) => rt }).head
private def irPortFromGlobalTarget(rt: ReferenceTarget): Port = {
println(s"Resolving port node from global ref ${rt}")
if (topConnects.contains(rt)) {
println(s"${rt} is connected to ${topConnects(rt)} (in some direction)")
} else {
println(s"Key ${rt} not found, dumping topConnects:")
topConnects.foreach {
case (k, v) => println(s" ${k} <> ${v}")
}
}
portNodes(topConnects(rt).pathlessTarget)
}
@ -234,8 +243,10 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F
private val visitedLeafPort = new LinkedHashSet[Port]()
private val visitedChannel = new LinkedHashMap[(Option[Port], Seq[Port]), String]()
private def channelSharesPorts(ps: (Option[Port], Seq[Port])): Boolean = (ps._1 ++: ps._2).exists(visitedLeafPort(_))
private def channelIsDuplicate(ps: (Option[Port], Seq[Port])): Boolean = visitedChannel.contains(ps)
private def channelSharesPorts(ps: (Option[Port], Seq[Port])): Boolean = ps match {
case (clk, ports) => ports.exists(visitedLeafPort(_)) // clock can be shared
}
private def dedupPortLists(pList: Map[String, (Option[Port], Seq[Port])]): Map[String, (Option[Port], Seq[Port])] = pList.flatMap({
case (cName, (_, Nil)) => throw new RuntimeException(s"Channel ${cName} is empty (has no associated ports)")

View File

@ -25,23 +25,22 @@ class WrapTop extends Transform {
val clocks = topModule.ports.filter(_.tpe == ClockType)
val hostClock = clocks.find(_.name == "clock").getOrElse(clocks.head)
val hostReset = HostReset.makePort(topWrapperNS)
val oldCircuitTarget = CircuitTarget(topName)
val topWrapperTarget = ModuleTarget(topWrapperName, topWrapperName)
val topWrapper = Module(NoInfo, topWrapperName, topModule.ports :+ hostReset, Block(topInstance +: portConnections))
val specialPortAnnotations = Seq(FAMEHostClock(topWrapperTarget.ref(hostClock.name)), FAMEHostReset(topWrapperTarget.ref(hostReset.name)))
val renames = RenameMap()
val newCircuit = Circuit(state.circuit.info, topWrapper +: state.circuit.modules, topWrapperName)
// Make channel annotations point at top-level ports
val fccaRenames = RenameMap()
fccaRenames.record(oldCircuitTarget.module(topName), oldCircuitTarget.module(topWrapperName))
val updatedAnnotations = state.annotations.map({
case fca: FAMEChannelConnectionAnnotation =>
fca.copy(sinks = fca.sinks.map(_.map(_.copy(module = topWrapperName))), sources = fca.sources.map(_.map(_.copy(module = topWrapperName))))
case a => a
}).map({ // Also update targets in info fields
case fca @ FAMEChannelConnectionAnnotation(_,info@DecoupledForwardChannel(_,_,_,_),_,_,_) =>
fca.copy(channelInfo = info.copy(
readySink = info.readySink. map(_.copy(module = topWrapperName)),
validSource = info.validSource.map(_.copy(module = topWrapperName)),
readySource = info.readySource.map(_.copy(module = topWrapperName)),
validSink = info.validSink. map(_.copy(module = topWrapperName))))
case fcca: FAMEChannelConnectionAnnotation =>
val renamedInfo = fcca.channelInfo match {
case fwd: DecoupledForwardChannel => fwd.update(fccaRenames)
case info => info
}
fcca.copy(channelInfo = renamedInfo).update(fccaRenames).head // always returns 1
case a => a
})