[gg] Bring up multiclock assertion synthesis; reenable AS in examples
This commit is contained in:
parent
df94490101
commit
2643c32d53
|
@ -4,7 +4,7 @@ package passes
|
|||
import java.io.{File, FileWriter, Writer}
|
||||
|
||||
import firrtl._
|
||||
import firrtl.annotations.{CircuitName, ModuleName, ComponentName, ModuleTarget}
|
||||
import firrtl.annotations._
|
||||
import firrtl.ir._
|
||||
import firrtl.Mappers._
|
||||
import firrtl.WrappedExpression._
|
||||
|
@ -23,16 +23,19 @@ private[passes] class AssertPass(
|
|||
(implicit p: Parameters) extends firrtl.Transform {
|
||||
def inputForm = LowForm
|
||||
def outputForm = HighForm
|
||||
override def name = "[MIDAS] Assertion Synthesis"
|
||||
override def name = "[Golden Gate] Assertion Synthesis"
|
||||
|
||||
type Asserts = collection.mutable.HashMap[String, (Int, String)]
|
||||
type Asserts = collection.mutable.HashMap[String, (Int, String, Expression)]
|
||||
type Messages = collection.mutable.HashMap[Int, String]
|
||||
|
||||
private val asserts = collection.mutable.HashMap[String, Asserts]()
|
||||
private val messages = collection.mutable.HashMap[String, Messages]()
|
||||
private val assertPorts = collection.mutable.HashMap[String, Port]()
|
||||
private val assertPorts = collection.mutable.HashMap[String, (Port, Port)]()
|
||||
private val excludeInstAsserts = collection.mutable.HashSet[(String, String)]()
|
||||
|
||||
private val clockVectorName = "midasAssertsClocks"
|
||||
private def getNameOrCroak(ns: Namespace, name: String): String = { assert(ns.tryName(name)); name }
|
||||
|
||||
// Helper method to filter out module instances
|
||||
private def excludeInst(excludes: Seq[(String, String)])
|
||||
(parentMod: String, inst: String): Boolean =
|
||||
|
@ -44,11 +47,11 @@ private[passes] class AssertPass(
|
|||
namespace: Namespace)
|
||||
(s: Statement): Statement =
|
||||
s map synAsserts(mname, namespace) match {
|
||||
case s: Stop if s.ret != 0 && !weq(s.en, zero) =>
|
||||
case Stop(info, ret, clk, en) if ret != 0 && !weq(en, zero) =>
|
||||
val idx = asserts(mname).size
|
||||
val name = namespace newName s"assert_$idx"
|
||||
asserts(mname)(s.en.serialize) = idx -> name
|
||||
DefNode(s.info, name, s.en)
|
||||
asserts(mname)(en.serialize) = (idx, name, clk)
|
||||
DefNode(info, name, en)
|
||||
case s => s
|
||||
}
|
||||
|
||||
|
@ -56,12 +59,12 @@ private[passes] class AssertPass(
|
|||
private def findMessages(mname: String)
|
||||
(s: Statement): Statement =
|
||||
s map findMessages(mname) match {
|
||||
// work around a large design assert build failure
|
||||
// work around a large design assert build failure
|
||||
// drop arguments and just show the format string
|
||||
//case s: Print if s.args.isEmpty =>
|
||||
case s: Print =>
|
||||
asserts(mname) get s.en.serialize match {
|
||||
case Some((idx, str)) =>
|
||||
case Some((idx, str, _)) =>
|
||||
messages(mname)(idx) = s.string.serialize
|
||||
EmptyStmt
|
||||
case _ => s
|
||||
|
@ -69,17 +72,12 @@ private[passes] class AssertPass(
|
|||
case s => s
|
||||
}
|
||||
|
||||
private def transform(meta: StroberMetaData)
|
||||
(m: DefModule): DefModule = {
|
||||
val namespace = Namespace(m)
|
||||
asserts(m.name) = new Asserts
|
||||
messages(m.name) = new Messages
|
||||
|
||||
def getChildren(ports: collection.mutable.Map[String, Port],
|
||||
private class ModuleAssertInfo(m: Module, meta: StroberMetaData) {
|
||||
def getChildren(ports: collection.mutable.Map[String, (Port, Port)],
|
||||
instExcludes: collection.mutable.HashSet[(String, String)]) = {
|
||||
(meta.childInsts(m.name)
|
||||
.filterNot(instName => excludeInst(instExcludes.toSeq)(m.name, instName))
|
||||
.foldRight(Seq[(String, Port)]())((x, res) =>
|
||||
.foldRight(Seq[(String, (Port, Port))]())((x, res) =>
|
||||
ports get meta.instModMap(x -> m.name) match {
|
||||
case None => res
|
||||
case Some(p) => res :+ (x -> p)
|
||||
|
@ -88,32 +86,70 @@ private[passes] class AssertPass(
|
|||
)
|
||||
}
|
||||
|
||||
(m map synAsserts(m.name, namespace)
|
||||
map findMessages(m.name)) match {
|
||||
case m: Module =>
|
||||
val ports = collection.mutable.ArrayBuffer[Port]()
|
||||
val stmts = collection.mutable.ArrayBuffer[Statement]()
|
||||
// Connect asserts
|
||||
val assertChildren = getChildren(assertPorts, excludeInstAsserts)
|
||||
val assertWidth = asserts(m.name).size + ((assertChildren foldLeft 0)(
|
||||
(res, x) => res + firrtl.bitWidth(x._2.tpe).toInt))
|
||||
if (assertWidth > 0) {
|
||||
val tpe = UIntType(IntWidth(assertWidth))
|
||||
val port = Port(NoInfo, namespace.newName("midasAsserts"), Output, tpe)
|
||||
val stmt = Connect(NoInfo, WRef(port.name), cat(
|
||||
(assertChildren map (x => wsub(wref(x._1), x._2.name))) ++
|
||||
(asserts(m.name).values.toSeq sortWith (_._1 > _._1) map (x => wref(x._2)))))
|
||||
assertPorts(m.name) = port
|
||||
ports += port
|
||||
stmts += stmt
|
||||
}
|
||||
m.copy(ports = m.ports ++ ports.toSeq, body = Block(m.body +: stmts.toSeq))
|
||||
case m: ExtModule => m
|
||||
}
|
||||
val assertChildren = getChildren(assertPorts, excludeInstAsserts)
|
||||
val assertWidth = asserts(m.name).size + ((assertChildren foldLeft 0)({
|
||||
case (sum, (_, (assertP, _))) => sum + firrtl.bitWidth(assertP.tpe).toInt
|
||||
}))
|
||||
|
||||
def hasAsserts(): Boolean = assertWidth > 0
|
||||
|
||||
// Get references to assertion ports on all child instances
|
||||
lazy val (childAsserts, childAssertClocks): (Seq[WSubField], Seq[Seq[WSubIndex]]) =
|
||||
(for ((childInstName, (assertPort, clockPort)) <- assertChildren) yield {
|
||||
val childWidth = firrtl.bitWidth(assertPort.tpe).toInt
|
||||
val assertRef = wsub(wref(childInstName), assertPort.name)
|
||||
val clockRefs = Seq.tabulate(childWidth)(i => widx(wsub(wref(childInstName), clockPort.name), i))
|
||||
(assertRef, clockRefs)
|
||||
}).unzip
|
||||
|
||||
// Get references to all module-local synthesized assertions
|
||||
val sortedLocalAsserts = asserts(m.name).values.toSeq.sortWith (_._1 > _._1)
|
||||
val (localAsserts, localClocks) =
|
||||
sortedLocalAsserts.map({ case (_, en, clk) => (wref(en), clk) }).unzip
|
||||
|
||||
def allAsserts = childAsserts ++ localAsserts
|
||||
def allClocks = childAssertClocks.flatten ++ localClocks
|
||||
def assertUInt = UIntType(IntWidth(assertWidth))
|
||||
}
|
||||
|
||||
private def replaceStopsAndFindMessages(m: DefModule): DefModule = m match {
|
||||
case m: Module =>
|
||||
val namespace = Namespace(m)
|
||||
asserts(m.name) = new Asserts
|
||||
messages(m.name) = new Messages
|
||||
m.map(synAsserts(m.name, namespace))
|
||||
.map(findMessages(m.name))
|
||||
case m: ExtModule => m
|
||||
}
|
||||
|
||||
private def wireSynthesizedAssertions(meta: StroberMetaData)
|
||||
(m: DefModule): DefModule = m match {
|
||||
case m: Module =>
|
||||
val ports = collection.mutable.ArrayBuffer[Port]()
|
||||
val stmts = collection.mutable.ArrayBuffer[Statement]()
|
||||
val mInfo = new ModuleAssertInfo(m, meta)
|
||||
// Connect asserts
|
||||
if (mInfo.hasAsserts) {
|
||||
val namespace = Namespace(m)
|
||||
val tpe = mInfo.assertUInt
|
||||
val port = Port(NoInfo, namespace.newName("midasAsserts"), Output, tpe)
|
||||
val clockType = VectorType(ClockType, mInfo.assertWidth)
|
||||
val clockPort = Port(NoInfo, getNameOrCroak(namespace, clockVectorName), Output, clockType)
|
||||
val assertConnect = Connect(NoInfo, WRef(port.name), cat(mInfo.allAsserts))
|
||||
val clockConnects = for ((clock, idx) <- mInfo.allClocks.zipWithIndex) yield {
|
||||
Connect(NoInfo, widx(WRef(clockPort.name), idx), clock)
|
||||
}
|
||||
assertPorts(m.name) = (port, clockPort)
|
||||
ports ++= Seq(port, clockPort)
|
||||
stmts ++= (assertConnect +: clockConnects)
|
||||
}
|
||||
m.copy(ports = m.ports ++ ports.toSeq, body = Block(m.body +: stmts.toSeq))
|
||||
case m: ExtModule => m
|
||||
}
|
||||
|
||||
private var assertNum = 0
|
||||
def dump(writer: Writer, meta: StroberMetaData, mod: String, path: String) {
|
||||
asserts(mod).values.toSeq sortWith (_._1 < _._1) foreach { case (idx, _) =>
|
||||
asserts(mod).values.toSeq sortWith (_._1 < _._1) foreach { case (idx, _, _) =>
|
||||
writer write s"[id: $assertNum, module: $mod, path: $path]\n"
|
||||
writer write (messages(mod)(idx) replace ("""\n""", "\n"))
|
||||
writer write "0\n"
|
||||
|
@ -125,41 +161,85 @@ private[passes] class AssertPass(
|
|||
|
||||
}
|
||||
def synthesizeAsserts(state: CircuitState): CircuitState = {
|
||||
val c = state.circuit
|
||||
|
||||
// Step 1: Grab module-based exclusions
|
||||
state.annotations.collect {
|
||||
case a @ (_: ExcludeInstanceAssertsAnnotation) => excludeInstAsserts += a.target
|
||||
}
|
||||
|
||||
// Step 2: Replace stop statements (asserts) and find associated message
|
||||
val c = state.circuit.copy(modules = state.circuit.modules.map(replaceStopsAndFindMessages))
|
||||
val topModule = c.modules.collectFirst({ case m: Module if m.name == c.main => m }).get
|
||||
val namespace = Namespace(topModule)
|
||||
|
||||
// Step 3: Wiring assertions and associated clocks to the top-level
|
||||
val meta = StroberMetaData(c)
|
||||
val mods = postorder(c, meta)(transform(meta))
|
||||
val mods = postorder(c, meta)(wireSynthesizedAssertions(meta))
|
||||
val f = new FileWriter(new File(dir, s"${c.main}.asserts"))
|
||||
dump(f, meta, c.main, c.main)
|
||||
f.close
|
||||
|
||||
println(s"[MIDAS] total # of assertions synthesized: $assertNum")
|
||||
// Step 4: Use wired-clocks to associate assertions with particular input clocks
|
||||
val postWiredState = state.copy(circuit = c.copy(modules = mods), form = MidForm)
|
||||
val loweredState = Seq(new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new MiddleFirrtlToLowFirrtl).foldLeft(postWiredState)((state, xform) => xform.transform(state))
|
||||
val connectivity = (new firrtl.transforms.CheckCombLoops).analyze(loweredState)(c.main)
|
||||
val mainModulePortMap = loweredState.circuit.modules
|
||||
.collectFirst({ case m if m.name == c.main => m }).get
|
||||
.ports.map(p => p.name -> p).toMap
|
||||
|
||||
// For each clock in clock channel, list associated assert indices
|
||||
val groupedAsserts = Seq.tabulate(assertNum)(i => i)
|
||||
.groupBy({ idx =>
|
||||
val name = s"${clockVectorName}_$idx"
|
||||
val srcClockPorts = connectivity.getEdges(name).map(mainModulePortMap(_))
|
||||
assert(srcClockPorts.size == 1)
|
||||
srcClockPorts.head
|
||||
})
|
||||
|
||||
// Step 5: Re-wire the top-level module
|
||||
val ports = collection.mutable.ArrayBuffer[Port]()
|
||||
val stmts = collection.mutable.ArrayBuffer[Statement]()
|
||||
val assertAnnos = collection.mutable.ArrayBuffer[Annotation]()
|
||||
val mInfo = new ModuleAssertInfo(topModule, meta)
|
||||
|
||||
// Step 5a: Connect all assertions to a single wire to match the order of our previous analysis
|
||||
val allAssertsWire = DefWire(NoInfo, "allAsserts", mInfo.assertUInt)
|
||||
val allAssertConnect = Connect(NoInfo, WRef(allAssertsWire), cat(mInfo.allAsserts))
|
||||
stmts ++= Seq(allAssertsWire, allAssertConnect)
|
||||
|
||||
// Step 5b: Generate unique ports for each clock
|
||||
for ((clockName, asserts) <- groupedAsserts) {
|
||||
val portName = namespace.newName(s"midasAsserts_${clockName.name}")
|
||||
val clockPortName = namespace.newName(s"midasAsserts_${clockName.name}_clock")
|
||||
val tpe = UIntType(IntWidth(asserts.size))
|
||||
val port = Port(NoInfo, portName, Output, tpe)
|
||||
val clockPort = Port(NoInfo, clockPortName, Output, ClockType)
|
||||
ports ++= Seq(port, clockPort)
|
||||
val bitExtracts = asserts.map(idx => DoPrim(PrimOps.Bits, Seq(WRef(allAssertsWire)), Seq(idx, idx), UIntType(IntWidth(1))))
|
||||
val connectAsserts = Connect(NoInfo, WRef(port), cat(bitExtracts))
|
||||
val connectClock = Connect(NoInfo, WRef(clockPort), WRef(clockName))
|
||||
stmts ++= Seq(connectClock, connectAsserts)
|
||||
|
||||
val mName = ModuleName(c.main, CircuitName(c.main))
|
||||
val assertAnnos = if (assertNum > 0) {
|
||||
val portName = assertPorts(c.main).name
|
||||
val portRT = ModuleTarget(c.main, c.main).ref(portName)
|
||||
val fcca = FAMEChannelConnectionAnnotation.implicitlyClockedSource(portName, WireChannel, Seq(portRT))
|
||||
|
||||
val clockPortRT = ModuleTarget(c.main, c.main).ref(clockPortName)
|
||||
val fcca = FAMEChannelConnectionAnnotation.source(portName, WireChannel, Some(clockPortRT), Seq(portRT))
|
||||
val bridgeAnno = BridgeIOAnnotation(
|
||||
target = portRT,
|
||||
widget = Some((p: Parameters) => new AssertBridgeModule(assertNum)(p)),
|
||||
channelMapping = Map("" -> portName)
|
||||
)
|
||||
|
||||
Seq(fcca, bridgeAnno)
|
||||
} else {
|
||||
Seq()
|
||||
assertAnnos ++= Seq(fcca, bridgeAnno)
|
||||
}
|
||||
val wiredTopModule = topModule.copy(ports = topModule.ports ++ ports,
|
||||
body = Block(topModule.body +: stmts.toSeq))
|
||||
|
||||
println(s"[Golden Gate] total # of assertions synthesized: $assertNum")
|
||||
|
||||
state.copy(
|
||||
circuit = c.copy(modules = mods),
|
||||
circuit = c.copy(modules = wiredTopModule +: mods.filterNot(_.name == c.main)),
|
||||
form = HighForm,
|
||||
annotations = state.annotations ++ assertAnnos)
|
||||
annotations = state.annotations ++ assertAnnos
|
||||
)
|
||||
}
|
||||
|
||||
def execute(state: CircuitState): CircuitState = {
|
||||
|
|
|
@ -44,6 +44,10 @@ private[midas] class MidasTransforms(
|
|||
firrtl.passes.CommonSubexpressionElimination,
|
||||
new firrtl.transforms.DeadCodeElimination,
|
||||
EnsureNoTargetIO,
|
||||
new BridgeExtraction,
|
||||
new ResolveAndCheck,
|
||||
new HighFirrtlToMiddleFirrtl,
|
||||
new MiddleFirrtlToLowFirrtl,
|
||||
// NB: Carelessly removing this pass will break the FireSim manager as we always
|
||||
// need to generate the *.asserts file. Fix by baking into driver.
|
||||
new AssertPass(dir),
|
||||
|
@ -51,10 +55,7 @@ private[midas] class MidasTransforms(
|
|||
new ResolveAndCheck,
|
||||
new HighFirrtlToMiddleFirrtl,
|
||||
new MiddleFirrtlToLowFirrtl,
|
||||
new BridgeExtraction,
|
||||
new ResolveAndCheck,
|
||||
new EmitFirrtl("post-bridge-extraction.fir"),
|
||||
new MiddleFirrtlToLowFirrtl,
|
||||
fame.WrapTop,
|
||||
new ResolveAndCheck,
|
||||
new EmitFirrtl("post-wrap-top.fir")) ++
|
||||
|
|
|
@ -14,7 +14,7 @@ class NoConfig extends Config(Parameters.empty)
|
|||
// This is incomplete and must be mixed into a complete platform config
|
||||
class DefaultF1Config extends Config(new Config((site, here, up) => {
|
||||
case DesiredHostFrequency => 75
|
||||
case SynthAsserts => false
|
||||
case SynthAsserts => true
|
||||
case midas.GenerateMultiCycleRamModels => true
|
||||
case SynthPrints => false
|
||||
}) ++ new Config(new firesim.configs.WithEC2F1Artefacts ++ new WithDefaultMemModel ++ new midas.F1Config))
|
||||
|
|
Loading…
Reference in New Issue