[gg] Bring up multiclock assertion synthesis; reenable AS in examples

This commit is contained in:
David Biancolin 2019-11-23 20:13:50 -08:00
parent df94490101
commit 2643c32d53
3 changed files with 139 additions and 58 deletions

View File

@ -4,7 +4,7 @@ package passes
import java.io.{File, FileWriter, Writer}
import firrtl._
import firrtl.annotations.{CircuitName, ModuleName, ComponentName, ModuleTarget}
import firrtl.annotations._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.WrappedExpression._
@ -23,16 +23,19 @@ private[passes] class AssertPass(
(implicit p: Parameters) extends firrtl.Transform {
def inputForm = LowForm
def outputForm = HighForm
override def name = "[MIDAS] Assertion Synthesis"
override def name = "[Golden Gate] Assertion Synthesis"
type Asserts = collection.mutable.HashMap[String, (Int, String)]
type Asserts = collection.mutable.HashMap[String, (Int, String, Expression)]
type Messages = collection.mutable.HashMap[Int, String]
private val asserts = collection.mutable.HashMap[String, Asserts]()
private val messages = collection.mutable.HashMap[String, Messages]()
private val assertPorts = collection.mutable.HashMap[String, Port]()
private val assertPorts = collection.mutable.HashMap[String, (Port, Port)]()
private val excludeInstAsserts = collection.mutable.HashSet[(String, String)]()
private val clockVectorName = "midasAssertsClocks"
private def getNameOrCroak(ns: Namespace, name: String): String = { assert(ns.tryName(name)); name }
// Helper method to filter out module instances
private def excludeInst(excludes: Seq[(String, String)])
(parentMod: String, inst: String): Boolean =
@ -44,11 +47,11 @@ private[passes] class AssertPass(
namespace: Namespace)
(s: Statement): Statement =
s map synAsserts(mname, namespace) match {
case s: Stop if s.ret != 0 && !weq(s.en, zero) =>
case Stop(info, ret, clk, en) if ret != 0 && !weq(en, zero) =>
val idx = asserts(mname).size
val name = namespace newName s"assert_$idx"
asserts(mname)(s.en.serialize) = idx -> name
DefNode(s.info, name, s.en)
asserts(mname)(en.serialize) = (idx, name, clk)
DefNode(info, name, en)
case s => s
}
@ -56,12 +59,12 @@ private[passes] class AssertPass(
private def findMessages(mname: String)
(s: Statement): Statement =
s map findMessages(mname) match {
// work around a large design assert build failure
// work around a large design assert build failure
// drop arguments and just show the format string
//case s: Print if s.args.isEmpty =>
case s: Print =>
asserts(mname) get s.en.serialize match {
case Some((idx, str)) =>
case Some((idx, str, _)) =>
messages(mname)(idx) = s.string.serialize
EmptyStmt
case _ => s
@ -69,17 +72,12 @@ private[passes] class AssertPass(
case s => s
}
private def transform(meta: StroberMetaData)
(m: DefModule): DefModule = {
val namespace = Namespace(m)
asserts(m.name) = new Asserts
messages(m.name) = new Messages
def getChildren(ports: collection.mutable.Map[String, Port],
private class ModuleAssertInfo(m: Module, meta: StroberMetaData) {
def getChildren(ports: collection.mutable.Map[String, (Port, Port)],
instExcludes: collection.mutable.HashSet[(String, String)]) = {
(meta.childInsts(m.name)
.filterNot(instName => excludeInst(instExcludes.toSeq)(m.name, instName))
.foldRight(Seq[(String, Port)]())((x, res) =>
.foldRight(Seq[(String, (Port, Port))]())((x, res) =>
ports get meta.instModMap(x -> m.name) match {
case None => res
case Some(p) => res :+ (x -> p)
@ -88,32 +86,70 @@ private[passes] class AssertPass(
)
}
(m map synAsserts(m.name, namespace)
map findMessages(m.name)) match {
case m: Module =>
val ports = collection.mutable.ArrayBuffer[Port]()
val stmts = collection.mutable.ArrayBuffer[Statement]()
// Connect asserts
val assertChildren = getChildren(assertPorts, excludeInstAsserts)
val assertWidth = asserts(m.name).size + ((assertChildren foldLeft 0)(
(res, x) => res + firrtl.bitWidth(x._2.tpe).toInt))
if (assertWidth > 0) {
val tpe = UIntType(IntWidth(assertWidth))
val port = Port(NoInfo, namespace.newName("midasAsserts"), Output, tpe)
val stmt = Connect(NoInfo, WRef(port.name), cat(
(assertChildren map (x => wsub(wref(x._1), x._2.name))) ++
(asserts(m.name).values.toSeq sortWith (_._1 > _._1) map (x => wref(x._2)))))
assertPorts(m.name) = port
ports += port
stmts += stmt
}
m.copy(ports = m.ports ++ ports.toSeq, body = Block(m.body +: stmts.toSeq))
case m: ExtModule => m
}
val assertChildren = getChildren(assertPorts, excludeInstAsserts)
val assertWidth = asserts(m.name).size + ((assertChildren foldLeft 0)({
case (sum, (_, (assertP, _))) => sum + firrtl.bitWidth(assertP.tpe).toInt
}))
def hasAsserts(): Boolean = assertWidth > 0
// Get references to assertion ports on all child instances
lazy val (childAsserts, childAssertClocks): (Seq[WSubField], Seq[Seq[WSubIndex]]) =
(for ((childInstName, (assertPort, clockPort)) <- assertChildren) yield {
val childWidth = firrtl.bitWidth(assertPort.tpe).toInt
val assertRef = wsub(wref(childInstName), assertPort.name)
val clockRefs = Seq.tabulate(childWidth)(i => widx(wsub(wref(childInstName), clockPort.name), i))
(assertRef, clockRefs)
}).unzip
// Get references to all module-local synthesized assertions
val sortedLocalAsserts = asserts(m.name).values.toSeq.sortWith (_._1 > _._1)
val (localAsserts, localClocks) =
sortedLocalAsserts.map({ case (_, en, clk) => (wref(en), clk) }).unzip
def allAsserts = childAsserts ++ localAsserts
def allClocks = childAssertClocks.flatten ++ localClocks
def assertUInt = UIntType(IntWidth(assertWidth))
}
private def replaceStopsAndFindMessages(m: DefModule): DefModule = m match {
case m: Module =>
val namespace = Namespace(m)
asserts(m.name) = new Asserts
messages(m.name) = new Messages
m.map(synAsserts(m.name, namespace))
.map(findMessages(m.name))
case m: ExtModule => m
}
private def wireSynthesizedAssertions(meta: StroberMetaData)
(m: DefModule): DefModule = m match {
case m: Module =>
val ports = collection.mutable.ArrayBuffer[Port]()
val stmts = collection.mutable.ArrayBuffer[Statement]()
val mInfo = new ModuleAssertInfo(m, meta)
// Connect asserts
if (mInfo.hasAsserts) {
val namespace = Namespace(m)
val tpe = mInfo.assertUInt
val port = Port(NoInfo, namespace.newName("midasAsserts"), Output, tpe)
val clockType = VectorType(ClockType, mInfo.assertWidth)
val clockPort = Port(NoInfo, getNameOrCroak(namespace, clockVectorName), Output, clockType)
val assertConnect = Connect(NoInfo, WRef(port.name), cat(mInfo.allAsserts))
val clockConnects = for ((clock, idx) <- mInfo.allClocks.zipWithIndex) yield {
Connect(NoInfo, widx(WRef(clockPort.name), idx), clock)
}
assertPorts(m.name) = (port, clockPort)
ports ++= Seq(port, clockPort)
stmts ++= (assertConnect +: clockConnects)
}
m.copy(ports = m.ports ++ ports.toSeq, body = Block(m.body +: stmts.toSeq))
case m: ExtModule => m
}
private var assertNum = 0
def dump(writer: Writer, meta: StroberMetaData, mod: String, path: String) {
asserts(mod).values.toSeq sortWith (_._1 < _._1) foreach { case (idx, _) =>
asserts(mod).values.toSeq sortWith (_._1 < _._1) foreach { case (idx, _, _) =>
writer write s"[id: $assertNum, module: $mod, path: $path]\n"
writer write (messages(mod)(idx) replace ("""\n""", "\n"))
writer write "0\n"
@ -125,41 +161,85 @@ private[passes] class AssertPass(
}
def synthesizeAsserts(state: CircuitState): CircuitState = {
val c = state.circuit
// Step 1: Grab module-based exclusions
state.annotations.collect {
case a @ (_: ExcludeInstanceAssertsAnnotation) => excludeInstAsserts += a.target
}
// Step 2: Replace stop statements (asserts) and find associated message
val c = state.circuit.copy(modules = state.circuit.modules.map(replaceStopsAndFindMessages))
val topModule = c.modules.collectFirst({ case m: Module if m.name == c.main => m }).get
val namespace = Namespace(topModule)
// Step 3: Wiring assertions and associated clocks to the top-level
val meta = StroberMetaData(c)
val mods = postorder(c, meta)(transform(meta))
val mods = postorder(c, meta)(wireSynthesizedAssertions(meta))
val f = new FileWriter(new File(dir, s"${c.main}.asserts"))
dump(f, meta, c.main, c.main)
f.close
println(s"[MIDAS] total # of assertions synthesized: $assertNum")
// Step 4: Use wired-clocks to associate assertions with particular input clocks
val postWiredState = state.copy(circuit = c.copy(modules = mods), form = MidForm)
val loweredState = Seq(new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new MiddleFirrtlToLowFirrtl).foldLeft(postWiredState)((state, xform) => xform.transform(state))
val connectivity = (new firrtl.transforms.CheckCombLoops).analyze(loweredState)(c.main)
val mainModulePortMap = loweredState.circuit.modules
.collectFirst({ case m if m.name == c.main => m }).get
.ports.map(p => p.name -> p).toMap
// For each clock in clock channel, list associated assert indices
val groupedAsserts = Seq.tabulate(assertNum)(i => i)
.groupBy({ idx =>
val name = s"${clockVectorName}_$idx"
val srcClockPorts = connectivity.getEdges(name).map(mainModulePortMap(_))
assert(srcClockPorts.size == 1)
srcClockPorts.head
})
// Step 5: Re-wire the top-level module
val ports = collection.mutable.ArrayBuffer[Port]()
val stmts = collection.mutable.ArrayBuffer[Statement]()
val assertAnnos = collection.mutable.ArrayBuffer[Annotation]()
val mInfo = new ModuleAssertInfo(topModule, meta)
// Step 5a: Connect all assertions to a single wire to match the order of our previous analysis
val allAssertsWire = DefWire(NoInfo, "allAsserts", mInfo.assertUInt)
val allAssertConnect = Connect(NoInfo, WRef(allAssertsWire), cat(mInfo.allAsserts))
stmts ++= Seq(allAssertsWire, allAssertConnect)
// Step 5b: Generate unique ports for each clock
for ((clockName, asserts) <- groupedAsserts) {
val portName = namespace.newName(s"midasAsserts_${clockName.name}")
val clockPortName = namespace.newName(s"midasAsserts_${clockName.name}_clock")
val tpe = UIntType(IntWidth(asserts.size))
val port = Port(NoInfo, portName, Output, tpe)
val clockPort = Port(NoInfo, clockPortName, Output, ClockType)
ports ++= Seq(port, clockPort)
val bitExtracts = asserts.map(idx => DoPrim(PrimOps.Bits, Seq(WRef(allAssertsWire)), Seq(idx, idx), UIntType(IntWidth(1))))
val connectAsserts = Connect(NoInfo, WRef(port), cat(bitExtracts))
val connectClock = Connect(NoInfo, WRef(clockPort), WRef(clockName))
stmts ++= Seq(connectClock, connectAsserts)
val mName = ModuleName(c.main, CircuitName(c.main))
val assertAnnos = if (assertNum > 0) {
val portName = assertPorts(c.main).name
val portRT = ModuleTarget(c.main, c.main).ref(portName)
val fcca = FAMEChannelConnectionAnnotation.implicitlyClockedSource(portName, WireChannel, Seq(portRT))
val clockPortRT = ModuleTarget(c.main, c.main).ref(clockPortName)
val fcca = FAMEChannelConnectionAnnotation.source(portName, WireChannel, Some(clockPortRT), Seq(portRT))
val bridgeAnno = BridgeIOAnnotation(
target = portRT,
widget = Some((p: Parameters) => new AssertBridgeModule(assertNum)(p)),
channelMapping = Map("" -> portName)
)
Seq(fcca, bridgeAnno)
} else {
Seq()
assertAnnos ++= Seq(fcca, bridgeAnno)
}
val wiredTopModule = topModule.copy(ports = topModule.ports ++ ports,
body = Block(topModule.body +: stmts.toSeq))
println(s"[Golden Gate] total # of assertions synthesized: $assertNum")
state.copy(
circuit = c.copy(modules = mods),
circuit = c.copy(modules = wiredTopModule +: mods.filterNot(_.name == c.main)),
form = HighForm,
annotations = state.annotations ++ assertAnnos)
annotations = state.annotations ++ assertAnnos
)
}
def execute(state: CircuitState): CircuitState = {

View File

@ -44,6 +44,10 @@ private[midas] class MidasTransforms(
firrtl.passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination,
EnsureNoTargetIO,
new BridgeExtraction,
new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new MiddleFirrtlToLowFirrtl,
// NB: Carelessly removing this pass will break the FireSim manager as we always
// need to generate the *.asserts file. Fix by baking into driver.
new AssertPass(dir),
@ -51,10 +55,7 @@ private[midas] class MidasTransforms(
new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new MiddleFirrtlToLowFirrtl,
new BridgeExtraction,
new ResolveAndCheck,
new EmitFirrtl("post-bridge-extraction.fir"),
new MiddleFirrtlToLowFirrtl,
fame.WrapTop,
new ResolveAndCheck,
new EmitFirrtl("post-wrap-top.fir")) ++

View File

@ -14,7 +14,7 @@ class NoConfig extends Config(Parameters.empty)
// This is incomplete and must be mixed into a complete platform config
class DefaultF1Config extends Config(new Config((site, here, up) => {
case DesiredHostFrequency => 75
case SynthAsserts => false
case SynthAsserts => true
case midas.GenerateMultiCycleRamModels => true
case SynthPrints => false
}) ++ new Config(new firesim.configs.WithEC2F1Artefacts ++ new WithDefaultMemModel ++ new midas.F1Config))