[gg] Make FindClockSources a Transform; multiclock print synthesis

This commit is contained in:
David Biancolin 2019-12-17 23:24:22 -08:00
parent 1cde84538a
commit 6cfdc30aee
5 changed files with 90 additions and 42 deletions

View File

@ -187,7 +187,7 @@ private[passes] class AssertPass(
// Step 4: Associate each assertion with a source clock
val postWiredState = state.copy(circuit = c.copy(modules = mods), form = MidForm)
val loweredState = Seq(new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new MiddleFirrtlToLowFirrtl).foldLeft(postWiredState)((state, xform) => xform.transform(state))
val clockMapping = FindClockSources(loweredState, mInfo.allClocks)
val clockMapping = FindClockSources.analyze(loweredState, mInfo.allClocks)
val rootClocks = mInfo.allClocks.map(clockMapping)
// For each clock in clock channel, list associated assert indices

View File

@ -9,8 +9,25 @@ import firrtl.annotations.TargetToken.{OfModule, Instance}
import firrtl.graph.{DiGraph}
import firrtl.analyses.InstanceGraph
object FindClockSources {
private def getSourceClock(moduleGraphs: Map[String,DiGraph[LogicNode]])
import midas.passes.fame.RTRenamer
case class FindClockSourceAnnotation(
target: ReferenceTarget,
originalTarget: Option[ReferenceTarget] = None) extends Annotation {
require(target.module == target.circuit, s"Queried leaf clock ${target} must provide an absolute instance path")
def update(renames: RenameMap): Seq[FindClockSourceAnnotation] =
Seq(this.copy(RTRenamer.exact(renames)(target), originalTarget.orElse(Some(target))))
}
case class ClockSourceAnnotation(queryTarget: ReferenceTarget, source: ReferenceTarget) extends Annotation {
def update(renames: RenameMap): Seq[ClockSourceAnnotation] = Seq(this.copy(queryTarget, RTRenamer.exact(renames)(source)))
}
object FindClockSources extends firrtl.Transform {
def inputForm = LowForm
def outputForm = LowForm
private def getSourceClock(moduleGraphs: Map[String,DiGraph[LogicNode]],
moduleDeps: Map[String, Map[String,String]])
(rT: ReferenceTarget): ReferenceTarget = {
val modulePath = (OfModule(rT.module) +: rT.path.map(_._2)).reverse
val instancePath = (None +: rT.path.map(tuple => Some(tuple._1))).reverse
@ -20,17 +37,22 @@ object FindClockSources {
val (instOpt, module) :: restOfPath = path
val mGraph = moduleGraphs(module.value)
val source = if (mGraph.findSinks(currentNode)) currentNode else mGraph.reachableFrom(currentNode).head
require(source.inst == None, "TODO: Handle clocks that may traverse a submodule hierarchy")
restOfPath match {
case Nil => source
case _ => walkModule(LogicNode(source.name, instOpt.map(_.value)), restOfPath)
(restOfPath, source) match {
// Source is a port on the top level module -> we're done
case (Nil, LogicNode(_, None, _)) => source
// Source is a port on an instance in our current module; prepend it to the path and recurse
case (_, LogicNode(port, Some(instName),_)) =>
val childModule = moduleDeps(module.value)(instName)
walkModule(LogicNode(port), (Some(Instance(instName)), OfModule(childModule)) +: path)
// Source is a port but we are not yet at the top; recurse into parent module
case (nonEmptyPath, _) => walkModule(LogicNode(source.name, instOpt.map(_.value)), nonEmptyPath)
}
}
val sourceClock = walkModule(LogicNode(rT.ref), instancePath.zip(modulePath))
rT.moduleTarget.ref(sourceClock.name)
}
def apply(state: CircuitState, queryTargets: Seq[ReferenceTarget]): Map[ReferenceTarget, ReferenceTarget] = {
def analyze(state: CircuitState, queryTargets: Seq[ReferenceTarget]): Map[ReferenceTarget, ReferenceTarget] = {
queryTargets.foreach(t => {
require(t.component == Nil)
require(t.module == t.circuit, s"Queried leaf clock ${t} must provide an absolute instance path")
@ -52,6 +74,18 @@ object FindClockSources {
val simplifiedSubgraph = subgraph.simplify((clockPorts ++ instClockNodes ++ queriedNodes).toSet)
module -> simplifiedSubgraph
}
(queryTargets.map(qT => qT -> getSourceClock(clockConnectivity.toMap)(qT))).toMap
(queryTargets.map(qT => qT -> getSourceClock(clockConnectivity.toMap, moduleDeps)(qT))).toMap
}
def execute(state: CircuitState): CircuitState = {
val queryAnnotations = state.annotations.collect({ case anno: FindClockSourceAnnotation => anno })
val sourceMappings = analyze(state, queryAnnotations.map(_.target))
val clockSourceAnnotations = queryAnnotations.map(qAnno =>
ClockSourceAnnotation(qAnno.originalTarget.getOrElse(qAnno.target), sourceMappings(qAnno.target)))
val prunedAnnos = state.annotations.flatMap({
case _: FindClockSourceAnnotation => None
case o => Some(o)
})
state.copy(annotations = clockSourceAnnotations ++ prunedAnnos)
}
}

View File

@ -29,11 +29,10 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
private val printMods = new mutable.HashSet[ModuleTarget]()
private val formatStringMap = new mutable.HashMap[ReferenceTarget, String]()
// Generates a bundle to aggregate
// Generates a bundle containing a print's clock, enable, and argument fields
def genPrintBundleType(print: Print): Type = BundleType(Seq(
Field("enable", Default, BoolType)) ++
print.args.zipWithIndex.map({ case (arg, idx) => Field(s"args_${idx}", Default, arg.tpe) })
)
Field("clock", Default, ClockType), Field("enable", Default, BoolType)) ++
print.args.zipWithIndex.map({ case (arg, idx) => Field(s"args_${idx}", Default, arg.tpe) }))
def getPrintName(p: Print, anno: SynthPrintfAnnotation, ns: Namespace): String = {
// If the user provided a name in the annotation use it; otherwise use the source locator
@ -57,17 +56,17 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
// Takes a single printPort and emits an FCCA for each field
def genFCCAsFromPort(mT: ModuleTarget, p: Port): Seq[FAMEChannelConnectionAnnotation] = {
println(p)
p.tpe match {
case BundleType(fields) =>
fields.map(field =>
FAMEChannelConnectionAnnotation.implicitlyClockedSource(
case BundleType(clockField :: dataFields) =>
dataFields.map(field =>
FAMEChannelConnectionAnnotation.source(
p.name + "_" + field.name,
WireChannel,
clock = Some(mT.ref(p.name).field(clockField.name)),
Seq(mT.ref(p.name).field(field.name))
)
)
case other => Seq()
case other => ???
}
}
@ -75,7 +74,10 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
require(state.annotations.collect({ case t: TopWiringAnnotation => t }).isEmpty,
"CircuitState cannot have existing TopWiring annotations before PrintSynthesis.")
val c = state.circuit
def mTarget(m: Module): ModuleTarget = ModuleTarget(c.main, m.name)
def portRT(p: Port): ReferenceTarget = ModuleTarget(c.main, c.main).ref(p.name)
def portClockRT(p: Port): ReferenceTarget = portRT(p).field("clock")
val modToAnnos = printfAnnos.groupBy(_.mod)
@ -92,12 +94,13 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
def onStmt(annos: Seq[SynthPrintfAnnotation], modNamespace: Namespace)
(s: Statement): Statement = s.map(onStmt(annos, modNamespace)) match {
case p @ Print(_,format,args,_,en) if annos.exists(_.format == format.string) =>
case p @ Print(_,format,args,clk,en) if annos.exists(_.format == format.string) =>
val associatedAnno = annos.find(_.format == format.string).get
val printName = getPrintName(p, associatedAnno, modNamespace)
// Generate an aggregate with all of our arguments; this will be wired out
val wire = DefWire(NoInfo, printName, genPrintBundleType(p))
val enableConnect = Connect(NoInfo, wsub(WRef(wire), s"enable"), en)
val clockConnect = Connect(NoInfo, wsub(WRef(wire), "clock"), clk)
val enableConnect = Connect(NoInfo, wsub(WRef(wire), "enable"), en)
val argumentConnects = (p.args.zipWithIndex).map({ case (arg, idx) =>
Connect(NoInfo,
wsub(WRef(wire), s"args_${idx}"),
@ -106,15 +109,16 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
val printBundleTarget = associatedAnno.mod.ref(printName)
topWiringAnnos += TopWiringAnnotation(printBundleTarget, topWiringPrefix)
formatStringMap(printBundleTarget) = format.serialize
Block(Seq(p, wire, enableConnect) ++ argumentConnects)
Block(Seq(p, wire, clockConnect, enableConnect) ++ argumentConnects)
case s => s
}
// Step 1: Find and replace printfs with stubs
val processedCircuit = c.map(onModule)
// Step 2: Wire out print stubs to top level module
val wiredState = (new TopWiringTransform).execute(state.copy(
circuit = processedCircuit,
annotations = state.annotations ++ topWiringAnnos))
val topModule = wiredState.circuit.modules.find(_.name == wiredState.circuit.main).get
val portMap: Map[String, Port] = topModule.ports.map(port => port.name -> port).toMap
val addedPrintPorts = topLevelOutputs.map({ case ((cname,_,_,path,prefix),_) =>
@ -124,31 +128,39 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
(port, formatString)
})
// Step 3: Using each printf bundle clock, group them into separate clock domains
val findClockSourceAnnos = addedPrintPorts.map({ case (port, _) => FindClockSourceAnnotation(portClockRT(port)) })
val stateToAnalyze = wiredState.copy(annotations = findClockSourceAnnos ++ wiredState.annotations)
val loweredState = Seq(new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new MiddleFirrtlToLowFirrtl,
FindClockSources).foldLeft(stateToAnalyze)((state, xform) => xform.transform(state))
val clockMapping = loweredState.annotations.collect({ case ClockSourceAnnotation(qT, source) => qT -> source }).toMap
val groupedPrints = addedPrintPorts.groupBy({ case (port, _) => clockMapping(portClockRT(port)) })
println(s"[MIDAS] total # of prints synthesized: ${addedPrintPorts.size}")
val printRecordAnno = addedPrintPorts match {
case Nil => Seq()
case ports => {
// TODO: Generate sensible channel annotations once we can aggregate wire channels
val portName = topWiringPrefix.stripSuffix("_")
val mT = ModuleTarget(c.main, c.main)
val portRT = mT.ref(portName)
val fccaAnnos = ports.flatMap({ case (port, _) => genFCCAsFromPort(mT, port) })
val bridgeAnno = BridgeIOAnnotation(
target = portRT,
widget = (p: Parameters) => new PrintBridgeModule(topWiringPrefix, addedPrintPorts)(p),
channelNames = fccaAnnos.map(_.globalName)
)
bridgeAnno +: fccaAnnos
}
// Step 4: Generate FCCAs and Bridge Annotations for each clock domain
val printRecordAnnos = for ((clockRT, ports) <- groupedPrints) yield {
// TODO: Generate sensible channel annotations once we can aggregate wire channels
val portName = topWiringPrefix.stripSuffix("_")
val mT = ModuleTarget(c.main, c.main)
val portRT = mT.ref(portName)
val fccaAnnos = ports.flatMap({ case (port, _) => genFCCAsFromPort(mT, port) })
val bridgeAnno = BridgeIOAnnotation(
target = portRT,
widget = (p: Parameters) => new PrintBridgeModule(topWiringPrefix, ports)(p),
channelNames = fccaAnnos.map(_.globalName)
)
bridgeAnno +: fccaAnnos
}
// Remove added TopWiringAnnotations to prevent being reconsumed by a downstream pass
val cleanedAnnotations = wiredState.annotations.flatMap({
case TopWiringAnnotation(_,_) => None
case otherAnno => Some(otherAnno)
})
wiredState.copy(annotations = cleanedAnnotations ++ printRecordAnno)
wiredState.copy(annotations = cleanedAnnotations ++ printRecordAnnos.toSeq.flatten)
}
def execute(state: CircuitState): CircuitState = {

View File

@ -17,11 +17,13 @@ class PrintRecord(portType: firrtl.ir.BundleType, val formatString: String) exte
def regenLeafType(tpe: firrtl.ir.Type): Data = tpe match {
case firrtl.ir.UIntType(width: firrtl.ir.IntWidth) => UInt(width.width.toInt.W)
case firrtl.ir.SIntType(width: firrtl.ir.IntWidth) => SInt(width.width.toInt.W)
case firrtl.ir.SIntType(width: firrtl.ir.IntWidth) => SInt(width.width.toInt.W)
case badType => throw new RuntimeException(s"Unexpected type in PrintBundle: ${badType}")
}
val args: Seq[(String, Data)] = portType.fields.collect({
case firrtl.ir.Field(name, _, tpe) if name != "enable" => (name -> Output(regenLeafType(tpe)))
case firrtl.ir.Field(name, _, tpe) if name != "enable" && name != "clock" =>
(name -> Output(regenLeafType(tpe)))
})
val enable = Output(Bool())

View File

@ -16,7 +16,7 @@ class DefaultF1Config extends Config(new Config((site, here, up) => {
case DesiredHostFrequency => 75
case SynthAsserts => true
case midas.GenerateMultiCycleRamModels => true
case SynthPrints => false
case SynthPrints => true
}) ++ new Config(new firesim.configs.WithEC2F1Artefacts ++ new WithDefaultMemModel ++ new midas.F1Config))
class PointerChaserConfig extends Config((site, here, up) => {