WIP on multiclock FAME transform

This commit is contained in:
Albert Magyar 2019-11-10 21:14:52 -08:00
parent e151aa56b1
commit 854096f1f3
6 changed files with 112 additions and 135 deletions

View File

@ -26,12 +26,13 @@ import annotations._
*/
case class FAMEChannelPortsAnnotation(
localName: String,
clockPort: Option[ReferenceTarget],
ports: Seq[ReferenceTarget]) extends Annotation {
def update(renames: RenameMap): Seq[Annotation] = {
val renamer = RTRenamer.exact(renames)
Seq(FAMEChannelPortsAnnotation(localName, ports.map(renamer)))
Seq(FAMEChannelPortsAnnotation(localName, clockPort.map(renamer), ports.map(renamer)))
}
override def getTargets: Seq[ReferenceTarget] = ports
override def getTargets: Seq[ReferenceTarget] = clockPort ++: ports
}
/**

View File

@ -25,13 +25,17 @@ import midas.passes._
trait FAME1Channel {
def name: String
def direction: Direction
def clockDomainEnable: Port
def ports: Seq[Port]
def firedReg: DefRegister
def tpe: Type = FAMEChannelAnalysis.getHostDecoupledChannelType(name, ports)
def portName: String
def asPort: Port = Port(NoInfo, portName, direction, tpe)
def isReady: Expression = WSubField(WRef(asPort), "ready", BoolType)
def isValid: Expression = WSubField(WRef(asPort), "valid", BoolType)
def isFiring: Expression = Reduce.and(Seq(isReady, isValid))
def isFiring: Expression = And(isReady, isValid)
def isFired = WRef(firedReg)
def isFiredOrFiring = Or(isFired, isFiring)
def replacePortRef(wr: WRef): WSubField = {
if (ports.size > 1) {
WSubField(WSubField(WRef(asPort), "bits"), FAMEChannelAnalysis.removeCommonPrefix(wr.name, name)._1)
@ -39,34 +43,24 @@ trait FAME1Channel {
WSubField(WRef(asPort), "bits")
}
}
}
case class FAME1InputChannel(val name: String, val ports: Seq[Port]) extends FAME1Channel {
val direction = Input
val portName = s"${name}_sink"
def genTokenLogic(finishing: WRef): Seq[Statement] = {
Seq(Connect(NoInfo, isReady, finishing))
/**
 * Builds the update of this channel's fired register.
 *
 * When `finishing` is high the register is loaded with the negation of the
 * channel's clock-domain enable; otherwise it takes `isFiredOrFiring`.
 *
 * @param finishing reference to the signal that selects the reload value
 * @return the register connect, wrapped in a Seq to match the declared result type
 */
def updateFiredReg(finishing: WRef): Seq[Statement] = {
  val nextFired = Mux(finishing, Negate(WRef(clockDomainEnable)), isFiredOrFiring, BoolType)
  // The body used to be a bare Connect, which does not conform to Seq[Statement]
  Seq(Connect(NoInfo, isFired, nextFired))
}
}
case class FAME1OutputChannel(val name: String, val ports: Seq[Port], val firedReg: DefRegister) extends FAME1Channel {
/**
 * Input-direction FAME1 channel; its host port is named `<name>_sink`.
 *
 * @param name channel name
 * @param clockDomainEnable port carrying the enable of this channel's clock domain
 * @param ports target ports carried by the channel
 * @param firedReg register tracking whether the channel has already fired
 */
case class FAME1InputChannel(
    name: String,
    clockDomainEnable: Port,
    ports: Seq[Port],
    firedReg: DefRegister) extends FAME1Channel {
  val direction = Input
  val portName = s"${name}_sink"

  /** Connects ready to `finishing && !isFired`. */
  def setReady(finishing: WRef): Statement = {
    val notYetFired = Negate(isFired)
    val readyCond = And(finishing, notYetFired)
    Connect(NoInfo, isReady, readyCond)
  }
}
case class FAME1OutputChannel(val name: String, val clockDomainEnable: Port, val ports: Seq[Port], val firedReg: DefRegister) extends FAME1Channel {
val direction = Output
val portName = s"${name}_source"
val isFired = WRef(firedReg)
val isFiredOrFiring = Reduce.or(Seq(isFired, isFiring))
def genTokenLogic(finishing: WRef, ccDeps: Iterable[FAME1InputChannel]): Seq[Statement] = {
val regUpdate = Connect(
NoInfo,
isFired,
Mux(finishing,
UIntLiteral(0, IntWidth(1)),
isFiredOrFiring,
BoolType))
val setValid = Connect(
NoInfo,
isValid,
Reduce.and(ccDeps.map(_.isValid) ++ Seq(Negate(isFired))))
Seq(regUpdate, setValid)
/**
 * Connects this channel's valid to the AND of the valids of every input
 * channel in `ccDeps` and the negation of this channel's fired flag.
 *
 * @param finishing not used by this method; kept so the signature is unchanged
 * @param ccDeps input channels whose valids gate this output
 */
def setValid(finishing: WRef, ccDeps: Iterable[FAME1InputChannel]): Statement = {
  // `:+` is a Seq operation and is not available on the Iterable returned by
  // ccDeps.map, so the extra term is appended with `++`.
  val conditions = ccDeps.map(_.isValid) ++ Seq(Negate(isFired))
  Connect(NoInfo, isValid, And.reduce(conditions))
}
}
@ -76,102 +70,107 @@ object ChannelCCDependencyGraph {
}
}
// Multi-clock timestep:
// When finishing is high, dequeue token from clock channel
// - Use to initialize isFired for all channels (with negation)
// - Finishing is gated with clock channel valid
object FAMEModuleTransformer {
def apply(m: Module, analysis: FAMEChannelAnalysis)(implicit triggerName: String): Module = {
// Step 0: Special signals & bookkeeping
def apply(m: Module, analysis: FAMEChannelAnalysis): Module = {
// Step 0: Bookkeeping for present naming and port structure conventions
implicit val ns = Namespace(m)
val clocks = m.ports.filter(_.tpe == ClockType)
// TODO: turn this back to == 1
val portsByName = m.ports.map(p => p.name -> p).toMap
assert(ns.tryName("hostClock") && ns.tryName("hostReset")) // for now, honor this convention
assert(clocks.length >= 1)
val hostClock = clocks.find(_.name == "clock").getOrElse(clocks.head) // TODO: naming convention for host clock
val hostReset = HostReset.makePort(ns)
def createHostReg(name: String = "host", width: Width = IntWidth(1)): DefRegister = {
new DefRegister(NoInfo, ns.newName(name), UIntType(width), WRef(hostClock), WRef(hostReset), UIntLiteral(0, width))
// Multi-clock management step 1: Add host clock + reset ports, finishing wire
val hostReset = Port(NoInfo, "hostReset", Input, BoolType)
val hostClock = Port(NoInfo, "hostClock", Input, ClockType)
val finishing = DefWire(NoInfo, ns.newName(triggerName), BoolType) // TODO: can this be a WrappedComponent
// Declares a UInt register of the given width that is clocked by hostClock and
// reset by hostReset to resetVal. The requested name is uniquified through ns.
def createHostReg(name: String = "host", width: Width = IntWidth(1), resetVal: Expression = UIntLiteral(0, width)): DefRegister = {
  val regName = ns.newName(name)
  val regType = UIntType(width)
  DefRegister(NoInfo, regName, regType, WRef(hostClock), WRef(hostReset), resetVal)
}
val finishing = DefWire(NoInfo, ns.newName(triggerName), BoolType)
val gateTargetClock = true
val buf = InstanceInfo(DefineAbstractClockGate.blackbox).connect("I", WRef(hostClock)).connect("CE", WRef(finishing))
val targetClock = SignalInfo(buf.decl, buf.assigns, WSubField(buf.ref, "O", ClockType, MALE))
// Multi-clock management step 2: Convert all clock ports to enables of same name
// copy() keeps each clock port's other fields and swaps only its type for Bool
val targetClockEns = clocks.map(_.copy(tpe = BoolType))
// Step 1: Build channels
val mTarget = ModuleTarget(analysis.circuit.main, m.name)
// Multi-clock management step 3: Generate clock buffers for all target clocks
val targetClockBufs = targetClockEns.map { en =>
  // Holds the enable captured the last time `finishing` was high; resets to 1
  val enableReg = createHostReg(s"${en.name}_enabled", resetVal = UIntLiteral(1, IntWidth(1)))
  // The two-argument WDefInstance factory takes the instance name first, then the module name
  val buf = WDefInstance(ns.newName(s"${en.name}_buffer"), DefineAbstractClockGate.blackbox.name)
  val connects = Seq(
    Connect(NoInfo, WRef(enableReg), Mux(WRef(finishing), WRef(en), WRef(enableReg), BoolType)),
    Connect(NoInfo, WSubField(WRef(buf), "I"), WRef(hostClock)),
    Connect(NoInfo, WSubField(WRef(buf), "CE"), And(WRef(enableReg), WRef(finishing))))
  // SignalInfo.assigns is a single Statement, so the connects are wrapped in a Block
  SignalInfo(Block(Seq(enableReg, buf)), Block(connects), WSubField(WRef(buf), "O", ClockType, MALE))
}
// Multi-clock management step 4: Generate target clock substitution map
val replaceClocksMap = (targetClockEns.map(p => we(p)) zip targetClockBufs.map(_.ref)).toMap
// LI-BDN transformation step 1: Build channels
val portDeps = analysis.connectivity(m.name)
val inChannels = (analysis.modelInputChannelPortMap(mTarget)).map({
case(cName, ports) => new FAME1InputChannel(cName, ports)
case(cName, Some(clock), ports) =>
val clockDomainEnable = portsByName(portDeps(clock.name)).copy(tpe = BoolType)
new FAME1InputChannel(cName, clockDomainEnable, ports)
case (_, None, _) => ??? // clocks are currently mandatory in channels
})
val inChannelMap = new LinkedHashMap[String, FAME1InputChannel] ++
(inChannels.flatMap(c => c.ports.map(p => (p.name, c))))
val outChannels = analysis.modelOutputChannelPortMap(mTarget).map({
case(cName, ports) =>
case(cName, Some(clock), ports) =>
val clockDomainEnable = portsByName(portDeps(clock.name)).copy(tpe = BoolType)
val firedReg = createHostReg(name = ns.newName(s"${cName}_fired"))
new FAME1OutputChannel(cName, ports, firedReg)
new FAME1OutputChannel(cName, clockDomainEnable, ports, firedReg)
case (_, None, _) => ??? // clocks are currently mandatory in channels
})
val outChannelMap = new LinkedHashMap[String, FAME1OutputChannel] ++
(outChannels.flatMap(c => c.ports.map(p => (p.name, c))))
val decls = Seq(finishing) ++ outChannels.map(_.firedReg)
// Step 2: Find combinational dependencies
val ccChecker = new firrtl.transforms.CheckCombLoops
val portDeps = analysis.connectivity(m.name)
// LI-BDN transformation step 2: find combinational dependencies among channels
val ccDeps = new LinkedHashMap[FAME1OutputChannel, LinkedHashSet[FAME1InputChannel]]
portDeps.getEdgeMap.collect({ case (o, iSet) if outChannelMap.contains(o) =>
// Only add input channels, since output might depend on output RHS ref
ccDeps.getOrElseUpdate(outChannelMap(o), new LinkedHashSet[FAME1InputChannel]) ++= iSet.flatMap(inChannelMap.get(_))
})
// Step 3: transform ports
val transformedPorts = clocks ++ Seq(hostReset) ++ inChannels.map(_.asPort) ++ outChannels.map(_.asPort)
// LI-BDN transformation step 3: transform ports (includes new clock ports)
val transformedPorts = Seq(hostClock, hostReset) ++ targetClockEns ++ (inChannels ++ outChannels).map(_.asPort)
// Step 4: Replace refs and gate state updates
def onExpr(expr: Expression): Expression = expr.map(onExpr) match {
case iWR @ WRef(name, tpe, PortKind, MALE) if tpe != ClockType =>
// LI-BDN transformation step 4: replace port and clock references and gate state updates
def onExpr(expr: Expression): Expression = expr match {
case wr @ WRef(name, tpe, PortKind, MALE) if tpe != ClockType =>
// Generally MALE references to ports will be input channels, but RTL may use
// an assignment to an output port as something akin to a wire, so check output ports too.
inChannelMap.getOrElse(name, outChannelMap(name)).replacePortRef(iWR)
case oWR @ WRef(name, tpe, PortKind, FEMALE) if tpe != ClockType =>
outChannelMap(name).replacePortRef(oWR)
case wr: WRef if wr.name == hostClock.name =>
// Replace host clock references with target clock references
targetClock.ref
case e => e
case cWR @ WRef(name, ClockType, PortKind, MALE) =>
  // `wr` is not bound in this case; the matched reference is `cWR`.
  // The lookup key is wrapped with we(...) the same way the map's keys are built.
  replaceClocksMap(we(cWR))
case e => e map onExpr
}
/*
* A target state trigger is only needed if clocks are not gated. When using true clock gating,
* the transform still programmatically adds an extra enable signal to the state updates, so we
* pass in a constant one as the value of this enable signal. This spurious condition gets
* optimized away during ConstantPropagation. This homogeneity is helpful since a transformed
* module might be instantiated within multiple top-level simulation models, some of which may
* rely on true clock gating (if their corresponding target modules contain blackboxes) and some of
* which may not.
*/
val targetStateTrigger = if (gateTargetClock) one else WRef(finishing)
val transformedStmts = Seq(m.body.map(_.map(onExpr)))
def onStmt(stmt: Statement): Statement = stmt.map(onStmt).map(onExpr) match {
case conn @ Connect(info, lhs, _) if (kind(lhs) == RegKind) =>
Conditionally(info, targetStateTrigger, conn, EmptyStmt)
case mem: DefMemory => PatientMemTransformer(mem, targetStateTrigger, WRef(hostClock), ns)
case wi: WDefInstance if analysis.syncNativeModules.contains(analysis.moduleTarget(wi)) =>
new Block(Seq(wi, Connect(wi.info, WSubField(WRef(wi), triggerName), targetStateTrigger)))
case s: Stop => s.copy(en = DoPrim(PrimOps.And, Seq(targetStateTrigger, s.en), Seq.empty, BoolType))
case p: Print => p.copy(en = DoPrim(PrimOps.And, Seq(targetStateTrigger, p.en), Seq.empty, BoolType))
case s => s
}
// LI-BDN transformation step 5: add firing rules for output channels, trigger end of cycle
// This is modified for multi-clock, as each channel fires only when associated clock is enabled
val allFiredOrFiring = And.reduce(outChannels.map(_.isFiredOrFiring) ++ inChannels.map(_.isValid))
val clockChannelReady = ???
val clockChannelValid = ???
val transformedStmts = Seq(m.body.map(onStmt))
// Step 5: Add firing rules for output channels, trigger end of cycle
val ruleStmts = new mutable.ArrayBuffer[Statement]
ruleStmts ++= outChannels.flatMap(o => o.genTokenLogic(WRef(finishing), ccDeps(o)))
ruleStmts ++= inChannels.flatMap(i => i.genTokenLogic(WRef(finishing)))
ruleStmts += Connect(NoInfo, WRef(finishing),
Reduce.and(outChannels.map(_.isFiredOrFiring) ++ inChannels.map(_.isValid)))
val channelStateRules = (inChannels ++ outChannels).map(c => c.updateFiredReg)
val inputRules = inChannels.flatMap(i => i.setReady(WRef(finishing)))
val outputRules = outChannels.flatMap(o => o.setValid(WRef(finishing), ccDeps(o)))
val topRules = Seq(
Connect(NoInfo, clockChannelReady, allFiredOrFiring),
Connect(NoInfo, WRef(finishing), And(allFiredOrFiring, clockChannelValid)))
// Statements have to be conservatively ordered to satisfy declaration order
val allStmts = targetClock.decl +: (decls ++ transformedStmts ++ ruleStmts) :+ targetClock.assigns
Module(m.info, m.name, transformedPorts, new Block(allStmts))
val decls = finishing +: targetClockBufs.map(_.decl) ++ (inChannels ++ outChannels).map(_.firedReg)
val assigns = targetClockBufs.map(_.assigns) ++ channelStateRules ++ inputRules ++ outputRules ++ topRules
Module(m.info, m.name, transformedPorts, Block(decls ++: transformedStmts +: assigns))
}
}

View File

@ -135,6 +135,7 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F
val transformedSinks = new LinkedHashSet[String]
val transformedSources = new LinkedHashSet[String]
val clockPort = new LinkedHashMap[String, ReferenceTarget]
val sinkModel = new LinkedHashMap[String, InstanceTarget]
val sourceModel = new LinkedHashMap[String, InstanceTarget]
val sinkPorts = new LinkedHashMap[String, Seq[ReferenceTarget]]
@ -186,17 +187,17 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F
}
// Looks up all FAMEChannelPortsAnnotations bound to a model module, to generate a Map
// from channel name to port list
private def genModelChannelPortMap(direction: Option[Direction])(mTarget: ModuleTarget): Map[String, Seq[Port]] = {
// from channel name to clock option and port list
private def genModelChannelPortMap(direction: Option[Direction])(mTarget: ModuleTarget): Map[String, (Option[Port], Seq[Port])] = {
modelPorts(mTarget).collect({
case FAMEChannelPortsAnnotation(name, ports) if direction == None || portNodes(ports.head).direction == direction.get =>
(name, ports.map(portNodes(_)))
case FAMEChannelPortsAnnotation(name, clock, ports) if direction.forall(_ == portNodes(ports.head).direction) =>
  // Emit a key -> value pair so that .toMap yields Map[String, (Option[Port], Seq[Port])]
  name -> (clock.map(portNodes(_)), ports.map(portNodes(_)))
}).toMap
}
def modelInputChannelPortMap: ModuleTarget => Map[String, Seq[Port]] = genModelChannelPortMap(Some(Input))
def modelOutputChannelPortMap: ModuleTarget => Map[String, Seq[Port]] = genModelChannelPortMap(Some(Output))
def modelChannelPortMap: ModuleTarget => Map[String, Seq[Port]] = genModelChannelPortMap(None)
def modelInputChannelPortMap: ModuleTarget => Map[String, (Option[Port], Seq[Port])] = genModelChannelPortMap(Some(Input))
def modelOutputChannelPortMap: ModuleTarget => Map[String, (Option[Port], Seq[Port])] = genModelChannelPortMap(Some(Output))
def modelChannelPortMap: ModuleTarget => Map[String, (Option[Port], Seq[Port])] = genModelChannelPortMap(None)
def getSinkHostDecoupledChannelType(cName: String): Type = {
FAMEChannelAnalysis.getHostDecoupledChannelType(cName, sinkPorts(cName).map(portNodes(_)))
@ -230,9 +231,10 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F
Some(cName, ports)
}).toMap
val clockMap = dedupPortLists(clockPortByChannel(mTarget))
val inputPortMap = dedupPortLists(inputPortsByChannel(mTarget))
val outputPortMap = dedupPortLists(outputPortsByChannel(mTarget))
val completePortMap = inputPortMap ++ outputPortMap
val completePortMap = (inputPortMap ++ outputPortMap).map { case (cName, ports) => cName -> (ports, clockMap(cName)) }
}
lazy val modulePortDedupers = transformedModules.map((mT: ModuleTarget) => new ModulePortDeduper(mT))

View File

@ -21,7 +21,7 @@ class InferModelPorts extends Transform {
val analysis = new FAMEChannelAnalysis(state, FAME1Transform)
val cTarget = CircuitTarget(state.circuit.main)
val modelChannelPortsAnnos = analysis.modulePortDedupers.flatMap(deduper =>
deduper.completePortMap.flatMap({ case (cName, ports) => Seq(
deduper.completePortMap.flatMap({ case (cName, (ports, clk)) => Seq(
FAMEChannelPortsAnnotation(cName, ports.map(p => deduper.mTarget.ref(p.name)))) ++
// Label all the channel ports with don't touch so as to prevent
// annotation renaming from breaking downstream

View File

@ -48,11 +48,18 @@ object Negate {
def apply(arg: Expression): Expression = DoPrim(PrimOps.Not, Seq(arg), Seq.empty, arg.tpe)
}
object Reduce {
private def _reduce(op: PrimOp, args: Iterable[Expression]): Expression = {
args.tail.foldLeft(args.head){ (l, r) => DoPrim(op, Seq(l, r), Seq.empty, UIntType(IntWidth(1))) }
sealed trait BinaryBooleanOp {
def op: PrimOp
// Builds the two-operand primitive with no constant arguments and a one-bit
// UInt result type, matching the four-argument DoPrim form used elsewhere in this file.
def apply(l: Expression, r: Expression): DoPrim = DoPrim(op, Seq(l, r), Seq.empty, UIntType(IntWidth(1)))
/**
 * Left-folds `args` with this operator.
 *
 * The result is typed as Expression because a single-element input is
 * returned unchanged and therefore need not be a DoPrim.
 *
 * @param args operands to combine; must be non-empty
 * @throws IllegalArgumentException if `args` is empty
 */
def reduce(args: Iterable[Expression]): Expression = {
  require(args.nonEmpty, "cannot reduce an empty collection of expressions")
  args.tail.foldLeft(args.head) { (l, r) => apply(l, r) }
}
def and(args: Iterable[Expression]): Expression = _reduce(PrimOps.And, args)
def or(args: Iterable[Expression]): Expression = _reduce(PrimOps.Or, args)
}
/** BinaryBooleanOp whose primitive is PrimOps.And. */
object And extends BinaryBooleanOp {
  override val op = PrimOps.And
}
/** BinaryBooleanOp whose primitive is PrimOps.Or. */
object Or extends BinaryBooleanOp {
  override val op = PrimOps.Or
}

View File

@ -15,58 +15,26 @@ import scala.language.implicitConversions
package object passes {
/**
* A utility for keeping statements defining and connecting signals to a piece of hardware
* together with a reference to the component. This is useful for passes that insert hardware,
* since the "collateral" of that object can be kept in one place.
*/
trait WrappedComponent {
val decl: Statement
val assigns: Statement
val ref: Expression
}
case class SignalInfo(decl: Statement, assigns: Statement, rhsRef: Expression)
/**
* Holds the definition of a signal along with the statements that assign to it and its reference.
*/
case class SignalInfo(decl: Statement, assigns: Statement, ref: Expression) extends WrappedComponent
/**
* A utility for creating a wire that "echoes" the value of an existing expression.
*/
object PassThru {
def apply(source: WRef)(implicit ns: Namespace): SignalInfo = apply(source, source.name)
def apply(source: WRef)(implicit ns: Namespace): SignalInfo = echo(source, source.name)
def apply(source: WRef, suggestedName: String)(implicit ns: Namespace): SignalInfo = {
val decl = DefWire(NoInfo, ns.newName(suggestedName), source.tpe)
val ref = WRef(decl)
val rhsRef = WRef(decl)
SignalInfo(decl, Connect(NoInfo, WRef(decl), source), ref)
}
}
object InstanceInfo {
def apply(m: DefModule)(implicit ns: Namespace): InstanceInfo = {
val inst = fame.Instantiate(m, ns.newName(m.name))
InstanceInfo(inst, Block(Nil), WRef(inst))
}
}
/**
* Holds the declaration of an instance, along with the set of statements that create connections
* to its ports, along with a reference to the instance.
*/
case class InstanceInfo(decl: WDefInstance, assigns: Block, ref: WRef) extends WrappedComponent {
def addAssign(s: Statement): InstanceInfo = {
copy(assigns = Block(assigns.stmts :+ s))
}
def connect(pName: String, rhs: Expression): InstanceInfo = {
addAssign(Connect(NoInfo, WSubField(ref, pName), rhs))
}
def connect(lhs: Expression, pName: String): InstanceInfo = {
addAssign(Connect(NoInfo, lhs, WSubField(ref, pName)))
}
}
/**
* This pass ensures that the AbstractClockGate blackbox is defined in a circuit, so that it can
* later be instantiated. The blackbox clock gate has the following signature: