[gg] Make FindClockSources a Transform; multiclock print synthesis
This commit is contained in:
parent
1cde84538a
commit
6cfdc30aee
|
@ -187,7 +187,7 @@ private[passes] class AssertPass(
|
|||
// Step 4: Associate each assertion with a source clock
|
||||
val postWiredState = state.copy(circuit = c.copy(modules = mods), form = MidForm)
|
||||
val loweredState = Seq(new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new MiddleFirrtlToLowFirrtl).foldLeft(postWiredState)((state, xform) => xform.transform(state))
|
||||
val clockMapping = FindClockSources(loweredState, mInfo.allClocks)
|
||||
val clockMapping = FindClockSources.analyze(loweredState, mInfo.allClocks)
|
||||
val rootClocks = mInfo.allClocks.map(clockMapping)
|
||||
|
||||
// For each clock in clock channel, list associated assert indices
|
||||
|
|
|
@ -9,8 +9,25 @@ import firrtl.annotations.TargetToken.{OfModule, Instance}
|
|||
import firrtl.graph.{DiGraph}
|
||||
import firrtl.analyses.InstanceGraph
|
||||
|
||||
object FindClockSources {
|
||||
private def getSourceClock(moduleGraphs: Map[String,DiGraph[LogicNode]])
|
||||
import midas.passes.fame.RTRenamer
|
||||
|
||||
case class FindClockSourceAnnotation(
|
||||
target: ReferenceTarget,
|
||||
originalTarget: Option[ReferenceTarget] = None) extends Annotation {
|
||||
require(target.module == target.circuit, s"Queried leaf clock ${target} must provide an absolute instance path")
|
||||
def update(renames: RenameMap): Seq[FindClockSourceAnnotation] =
|
||||
Seq(this.copy(RTRenamer.exact(renames)(target), originalTarget.orElse(Some(target))))
|
||||
}
|
||||
|
||||
case class ClockSourceAnnotation(queryTarget: ReferenceTarget, source: ReferenceTarget) extends Annotation {
|
||||
def update(renames: RenameMap): Seq[ClockSourceAnnotation] = Seq(this.copy(queryTarget, RTRenamer.exact(renames)(source)))
|
||||
}
|
||||
|
||||
object FindClockSources extends firrtl.Transform {
|
||||
def inputForm = LowForm
|
||||
def outputForm = LowForm
|
||||
private def getSourceClock(moduleGraphs: Map[String,DiGraph[LogicNode]],
|
||||
moduleDeps: Map[String, Map[String,String]])
|
||||
(rT: ReferenceTarget): ReferenceTarget = {
|
||||
val modulePath = (OfModule(rT.module) +: rT.path.map(_._2)).reverse
|
||||
val instancePath = (None +: rT.path.map(tuple => Some(tuple._1))).reverse
|
||||
|
@ -20,17 +37,22 @@ object FindClockSources {
|
|||
val (instOpt, module) :: restOfPath = path
|
||||
val mGraph = moduleGraphs(module.value)
|
||||
val source = if (mGraph.findSinks(currentNode)) currentNode else mGraph.reachableFrom(currentNode).head
|
||||
require(source.inst == None, "TODO: Handle clocks that may traverse a submodule hierarchy")
|
||||
restOfPath match {
|
||||
case Nil => source
|
||||
case _ => walkModule(LogicNode(source.name, instOpt.map(_.value)), restOfPath)
|
||||
(restOfPath, source) match {
|
||||
// Source is a port on the top level module -> we're done
|
||||
case (Nil, LogicNode(_, None, _)) => source
|
||||
// Source is a port on an instance in our current module; prepend it to the path and recurse
|
||||
case (_, LogicNode(port, Some(instName),_)) =>
|
||||
val childModule = moduleDeps(module.value)(instName)
|
||||
walkModule(LogicNode(port), (Some(Instance(instName)), OfModule(childModule)) +: path)
|
||||
// Source is a port but we are not yet at the top; recurse into parent module
|
||||
case (nonEmptyPath, _) => walkModule(LogicNode(source.name, instOpt.map(_.value)), nonEmptyPath)
|
||||
}
|
||||
}
|
||||
val sourceClock = walkModule(LogicNode(rT.ref), instancePath.zip(modulePath))
|
||||
rT.moduleTarget.ref(sourceClock.name)
|
||||
}
|
||||
|
||||
def apply(state: CircuitState, queryTargets: Seq[ReferenceTarget]): Map[ReferenceTarget, ReferenceTarget] = {
|
||||
def analyze(state: CircuitState, queryTargets: Seq[ReferenceTarget]): Map[ReferenceTarget, ReferenceTarget] = {
|
||||
queryTargets.foreach(t => {
|
||||
require(t.component == Nil)
|
||||
require(t.module == t.circuit, s"Queried leaf clock ${t} must provide an absolute instance path")
|
||||
|
@ -52,6 +74,18 @@ object FindClockSources {
|
|||
val simplifiedSubgraph = subgraph.simplify((clockPorts ++ instClockNodes ++ queriedNodes).toSet)
|
||||
module -> simplifiedSubgraph
|
||||
}
|
||||
(queryTargets.map(qT => qT -> getSourceClock(clockConnectivity.toMap)(qT))).toMap
|
||||
(queryTargets.map(qT => qT -> getSourceClock(clockConnectivity.toMap, moduleDeps)(qT))).toMap
|
||||
}
|
||||
|
||||
def execute(state: CircuitState): CircuitState = {
|
||||
val queryAnnotations = state.annotations.collect({ case anno: FindClockSourceAnnotation => anno })
|
||||
val sourceMappings = analyze(state, queryAnnotations.map(_.target))
|
||||
val clockSourceAnnotations = queryAnnotations.map(qAnno =>
|
||||
ClockSourceAnnotation(qAnno.originalTarget.getOrElse(qAnno.target), sourceMappings(qAnno.target)))
|
||||
val prunedAnnos = state.annotations.flatMap({
|
||||
case _: FindClockSourceAnnotation => None
|
||||
case o => Some(o)
|
||||
})
|
||||
state.copy(annotations = clockSourceAnnotations ++ prunedAnnos)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,11 +29,10 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
|
|||
private val printMods = new mutable.HashSet[ModuleTarget]()
|
||||
private val formatStringMap = new mutable.HashMap[ReferenceTarget, String]()
|
||||
|
||||
// Generates a bundle to aggregate
|
||||
// Generates a bundle containing a print's clock, enable, and argument fields
|
||||
def genPrintBundleType(print: Print): Type = BundleType(Seq(
|
||||
Field("enable", Default, BoolType)) ++
|
||||
print.args.zipWithIndex.map({ case (arg, idx) => Field(s"args_${idx}", Default, arg.tpe) })
|
||||
)
|
||||
Field("clock", Default, ClockType), Field("enable", Default, BoolType)) ++
|
||||
print.args.zipWithIndex.map({ case (arg, idx) => Field(s"args_${idx}", Default, arg.tpe) }))
|
||||
|
||||
def getPrintName(p: Print, anno: SynthPrintfAnnotation, ns: Namespace): String = {
|
||||
// If the user provided a name in the annotation use it; otherwise use the source locator
|
||||
|
@ -57,17 +56,17 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
|
|||
|
||||
// Takes a single printPort and emits an FCCA for each field
|
||||
def genFCCAsFromPort(mT: ModuleTarget, p: Port): Seq[FAMEChannelConnectionAnnotation] = {
|
||||
println(p)
|
||||
p.tpe match {
|
||||
case BundleType(fields) =>
|
||||
fields.map(field =>
|
||||
FAMEChannelConnectionAnnotation.implicitlyClockedSource(
|
||||
case BundleType(clockField :: dataFields) =>
|
||||
dataFields.map(field =>
|
||||
FAMEChannelConnectionAnnotation.source(
|
||||
p.name + "_" + field.name,
|
||||
WireChannel,
|
||||
clock = Some(mT.ref(p.name).field(clockField.name)),
|
||||
Seq(mT.ref(p.name).field(field.name))
|
||||
)
|
||||
)
|
||||
case other => Seq()
|
||||
case other => ???
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,7 +74,10 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
|
|||
require(state.annotations.collect({ case t: TopWiringAnnotation => t }).isEmpty,
|
||||
"CircuitState cannot have existing TopWiring annotations before PrintSynthesis.")
|
||||
val c = state.circuit
|
||||
|
||||
def mTarget(m: Module): ModuleTarget = ModuleTarget(c.main, m.name)
|
||||
def portRT(p: Port): ReferenceTarget = ModuleTarget(c.main, c.main).ref(p.name)
|
||||
def portClockRT(p: Port): ReferenceTarget = portRT(p).field("clock")
|
||||
|
||||
val modToAnnos = printfAnnos.groupBy(_.mod)
|
||||
|
||||
|
@ -92,12 +94,13 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
|
|||
|
||||
def onStmt(annos: Seq[SynthPrintfAnnotation], modNamespace: Namespace)
|
||||
(s: Statement): Statement = s.map(onStmt(annos, modNamespace)) match {
|
||||
case p @ Print(_,format,args,_,en) if annos.exists(_.format == format.string) =>
|
||||
case p @ Print(_,format,args,clk,en) if annos.exists(_.format == format.string) =>
|
||||
val associatedAnno = annos.find(_.format == format.string).get
|
||||
val printName = getPrintName(p, associatedAnno, modNamespace)
|
||||
// Generate an aggregate with all of our arguments; this will be wired out
|
||||
val wire = DefWire(NoInfo, printName, genPrintBundleType(p))
|
||||
val enableConnect = Connect(NoInfo, wsub(WRef(wire), s"enable"), en)
|
||||
val clockConnect = Connect(NoInfo, wsub(WRef(wire), "clock"), clk)
|
||||
val enableConnect = Connect(NoInfo, wsub(WRef(wire), "enable"), en)
|
||||
val argumentConnects = (p.args.zipWithIndex).map({ case (arg, idx) =>
|
||||
Connect(NoInfo,
|
||||
wsub(WRef(wire), s"args_${idx}"),
|
||||
|
@ -106,15 +109,16 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
|
|||
val printBundleTarget = associatedAnno.mod.ref(printName)
|
||||
topWiringAnnos += TopWiringAnnotation(printBundleTarget, topWiringPrefix)
|
||||
formatStringMap(printBundleTarget) = format.serialize
|
||||
Block(Seq(p, wire, enableConnect) ++ argumentConnects)
|
||||
Block(Seq(p, wire, clockConnect, enableConnect) ++ argumentConnects)
|
||||
case s => s
|
||||
}
|
||||
|
||||
// Step 1: Find and replace printfs with stubs
|
||||
val processedCircuit = c.map(onModule)
|
||||
|
||||
// Step 2: Wire out print stubs to top level module
|
||||
val wiredState = (new TopWiringTransform).execute(state.copy(
|
||||
circuit = processedCircuit,
|
||||
annotations = state.annotations ++ topWiringAnnos))
|
||||
|
||||
val topModule = wiredState.circuit.modules.find(_.name == wiredState.circuit.main).get
|
||||
val portMap: Map[String, Port] = topModule.ports.map(port => port.name -> port).toMap
|
||||
val addedPrintPorts = topLevelOutputs.map({ case ((cname,_,_,path,prefix),_) =>
|
||||
|
@ -124,31 +128,39 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends
|
|||
(port, formatString)
|
||||
})
|
||||
|
||||
// Step 3: Using each printf bundle clock, group them into separate clock domains
|
||||
val findClockSourceAnnos = addedPrintPorts.map({ case (port, _) => FindClockSourceAnnotation(portClockRT(port)) })
|
||||
val stateToAnalyze = wiredState.copy(annotations = findClockSourceAnnos ++ wiredState.annotations)
|
||||
val loweredState = Seq(new ResolveAndCheck,
|
||||
new HighFirrtlToMiddleFirrtl,
|
||||
new MiddleFirrtlToLowFirrtl,
|
||||
FindClockSources).foldLeft(stateToAnalyze)((state, xform) => xform.transform(state))
|
||||
|
||||
val clockMapping = loweredState.annotations.collect({ case ClockSourceAnnotation(qT, source) => qT -> source }).toMap
|
||||
val groupedPrints = addedPrintPorts.groupBy({ case (port, _) => clockMapping(portClockRT(port)) })
|
||||
|
||||
println(s"[MIDAS] total # of prints synthesized: ${addedPrintPorts.size}")
|
||||
|
||||
val printRecordAnno = addedPrintPorts match {
|
||||
case Nil => Seq()
|
||||
case ports => {
|
||||
// TODO: Generate sensible channel annotations once we can aggregate wire channels
|
||||
val portName = topWiringPrefix.stripSuffix("_")
|
||||
val mT = ModuleTarget(c.main, c.main)
|
||||
val portRT = mT.ref(portName)
|
||||
|
||||
val fccaAnnos = ports.flatMap({ case (port, _) => genFCCAsFromPort(mT, port) })
|
||||
val bridgeAnno = BridgeIOAnnotation(
|
||||
target = portRT,
|
||||
widget = (p: Parameters) => new PrintBridgeModule(topWiringPrefix, addedPrintPorts)(p),
|
||||
channelNames = fccaAnnos.map(_.globalName)
|
||||
)
|
||||
bridgeAnno +: fccaAnnos
|
||||
}
|
||||
// Step 4: Generate FCCAs and Bridge Annotations for each clock domain
|
||||
val printRecordAnnos = for ((clockRT, ports) <- groupedPrints) yield {
|
||||
// TODO: Generate sensible channel annotations once we can aggregate wire channels
|
||||
val portName = topWiringPrefix.stripSuffix("_")
|
||||
val mT = ModuleTarget(c.main, c.main)
|
||||
val portRT = mT.ref(portName)
|
||||
val fccaAnnos = ports.flatMap({ case (port, _) => genFCCAsFromPort(mT, port) })
|
||||
val bridgeAnno = BridgeIOAnnotation(
|
||||
target = portRT,
|
||||
widget = (p: Parameters) => new PrintBridgeModule(topWiringPrefix, ports)(p),
|
||||
channelNames = fccaAnnos.map(_.globalName)
|
||||
)
|
||||
bridgeAnno +: fccaAnnos
|
||||
}
|
||||
// Remove added TopWiringAnnotations to prevent being reconsumed by a downstream pass
|
||||
val cleanedAnnotations = wiredState.annotations.flatMap({
|
||||
case TopWiringAnnotation(_,_) => None
|
||||
case otherAnno => Some(otherAnno)
|
||||
})
|
||||
wiredState.copy(annotations = cleanedAnnotations ++ printRecordAnno)
|
||||
wiredState.copy(annotations = cleanedAnnotations ++ printRecordAnnos.toSeq.flatten)
|
||||
}
|
||||
|
||||
def execute(state: CircuitState): CircuitState = {
|
||||
|
|
|
@ -17,11 +17,13 @@ class PrintRecord(portType: firrtl.ir.BundleType, val formatString: String) exte
|
|||
def regenLeafType(tpe: firrtl.ir.Type): Data = tpe match {
|
||||
case firrtl.ir.UIntType(width: firrtl.ir.IntWidth) => UInt(width.width.toInt.W)
|
||||
case firrtl.ir.SIntType(width: firrtl.ir.IntWidth) => SInt(width.width.toInt.W)
|
||||
case firrtl.ir.SIntType(width: firrtl.ir.IntWidth) => SInt(width.width.toInt.W)
|
||||
case badType => throw new RuntimeException(s"Unexpected type in PrintBundle: ${badType}")
|
||||
}
|
||||
|
||||
val args: Seq[(String, Data)] = portType.fields.collect({
|
||||
case firrtl.ir.Field(name, _, tpe) if name != "enable" => (name -> Output(regenLeafType(tpe)))
|
||||
case firrtl.ir.Field(name, _, tpe) if name != "enable" && name != "clock" =>
|
||||
(name -> Output(regenLeafType(tpe)))
|
||||
})
|
||||
|
||||
val enable = Output(Bool())
|
||||
|
|
|
@ -16,7 +16,7 @@ class DefaultF1Config extends Config(new Config((site, here, up) => {
|
|||
case DesiredHostFrequency => 75
|
||||
case SynthAsserts => true
|
||||
case midas.GenerateMultiCycleRamModels => true
|
||||
case SynthPrints => false
|
||||
case SynthPrints => true
|
||||
}) ++ new Config(new firesim.configs.WithEC2F1Artefacts ++ new WithDefaultMemModel ++ new midas.F1Config))
|
||||
|
||||
class PointerChaserConfig extends Config((site, here, up) => {
|
||||
|
|
Loading…
Reference in New Issue