Port multi-port serialized mem optimization to multi-clock

This commit is contained in:
Albert Magyar 2020-05-28 03:17:16 +00:00
parent 3143253dda
commit 2e635759f4
9 changed files with 102 additions and 46 deletions

View File

@ -6,11 +6,12 @@ import midas.widgets.{BridgeAnnotation, ClockBridgeAnnotation}
import midas.passes.fame.{PromoteSubmodule, PromoteSubmoduleAnnotation, FAMEChannelConnectionAnnotation}
import firrtl._
import firrtl.annotations._
import firrtl.ir._
import firrtl.passes.{InferTypes, ResolveKinds}
import firrtl.traversals.Foreachers._
import firrtl.transforms.TopWiring.{TopWiringAnnotation, TopWiringTransform, TopWiringOutputFilesAnnotation}
import firrtl.passes.wiring.{Wiring, WiringInfo}
import firrtl.annotations._
import Utils._
import scala.collection.mutable
@ -129,8 +130,13 @@ private[passes] class BridgeExtraction extends firrtl.Transform {
// Annotate all bridge instances
val instAnnoedState = annotateInstances(wrappedTopState)
def normalize(state: CircuitState): CircuitState = {
val cx = RemoveTrivialPartialConnects.run(InferTypes.run(ResolveKinds.run(state.circuit)))
state.copy(circuit = cx)
}
// Promote all modules that are annotated as bridges
val promotedState = promoteBridges(instAnnoedState)
val promotedState = normalize(promoteBridges(instAnnoedState))
// Propogate bridge annotations to the IO created on the true top module
val commutedState = commuteBridgeAnnotations(promotedState)

View File

@ -0,0 +1,27 @@
// See LICENSE for license details.
package midas.passes
import firrtl._
import firrtl.ir._
import firrtl.passes._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.options.{Dependency, PreservesAll}
object RemoveTrivialPartialConnects extends Pass with PreservesAll[Transform] {
override def prerequisites =
Dependency(InferTypes) +: Dependency(ResolveKinds) +: stage.Forms.WorkingIR
private def onStmt(stmt: Statement): Statement = stmt match {
case PartialConnect(i, l, e) if (l.tpe == e.tpe) => Connect(i, l, e)
case s => s.map(onStmt)
}
private def onModule(m: DefModule): DefModule = m.map(onStmt)
def run(c: Circuit): Circuit = {
c.map(onModule)
}
}

View File

@ -17,7 +17,6 @@ import firrtl.Mappers._
import firrtl.passes.LowerTypes.loweredName
import firrtl.Utils.{BoolType, splitRef, mergeRef, create_exps, flow, module_type}
import firrtl.passes.wiring._
import fame.{FAMEChannelConnectionAnnotation, FAMEChannelPortsAnnotation, FAMEChannelAnalysis, FAME1Transform}
import Utils._
import freechips.rocketchip.config.Parameters
import freechips.rocketchip.diplomacy.LazyModule
@ -117,7 +116,7 @@ private[passes] class SimulationMapping(targetName: String) extends firrtl.Trans
Map(CircuitName(innerCircuit.main) -> Seq(CircuitName(outerCircuit.main))))
val innerAnnos = loweredInnerState.annotations.filter(_ match {
case _: FAMEChannelConnectionAnnotation | _: FAMEChannelPortsAnnotation => false
case _: midas.targetutils.FAMEAnnotation => false
case _: BridgeIOAnnotation => false
case _ => true
})

View File

@ -185,15 +185,15 @@ class RAMModelInst(name: String, val readPorts: Seq[ReadPort], val writePorts: S
) ++ readConnects ++ writeConnects
}
def elaborateModel(parentCircuitNS: Namespace): (DefModule, Seq[Annotation]) = {
def elaborateModel(parentCircuitNS: Namespace): Module = {
val c3circuit = chisel3.Driver.elaborate(() =>
new midas.models.sram.AsyncMemChiselModel(depth.toInt, dataWidth, readPorts.size, writePorts.size))
val chirrtl = Parser.parse(chisel3.Driver.emit(c3circuit))
val annos = PreLinkRenamingAnnotation(parentCircuitNS) +: c3circuit.annotations.map(_.toFirrtl)
val state = new MiddleFirrtlCompiler().compile(CircuitState(chirrtl, ChirrtlForm, annos), Seq(PreLinkRenaming))
require(state.circuit.modules.size == 1)
val mod = state.circuit.modules.head
(mod, state.annotations)
val state = new MiddleFirrtlCompiler().compile(CircuitState(chirrtl, ChirrtlForm, Nil), Nil)
require(state.circuit.modules.length == 1)
state.circuit.modules.collectFirst({
case m: Module => m.copy(name = name)
}).get
}
}
@ -225,21 +225,13 @@ class EmitAndWrapRAMModels extends Transform {
val readWritePorts = annos.collect({ case anno: ModelReadWritePort => new ReadWritePort(anno, mod.ports)})
require(readWritePorts.isEmpty)
val (hostClock, updatedPortList) = mod.ports.find(p => p.tpe == ClockType && p.name == "clock") match {
case Some(port) => (port, mod.ports)
case None =>
val cPort = Port(NoInfo, "clock", Input, ClockType)
(cPort, cPort +: mod.ports)
}
val hostReset = mod.ports.find(_.name == "hostReset").get
val hostClock = mod.ports.find(_.name == WrapTop.hostClockName).get
val hostReset = mod.ports.find(_.name == WrapTop.hostResetName).get
val name = ns.newName("RamModel")
val inst = new RAMModelInst(name, readPorts, writePorts)
val (module, newAnnos) = inst.elaborateModel(Namespace(c))
addedModules += module
addedAnnotations ++= newAnnos
Module(NoInfo, mod.name, updatedPortList, Block(inst.emitStatements(WRef(hostClock), WRef(hostReset))))
addedModules += inst.elaborateModel(Namespace(c))
Module(NoInfo, mod.name, mod.ports, Block(inst.emitStatements(WRef(hostClock), WRef(hostReset))))
}
def onModule(mod: DefModule): DefModule = mod match {
@ -249,11 +241,6 @@ class EmitAndWrapRAMModels extends Transform {
val newCircuit = state.circuit.map(onModule)
val renames = RenameMap.create(addedModules.map(m =>
ModuleTarget(m.name, m.name) -> Seq(CircuitTarget(c.main))).toMap)
state.copy(circuit = newCircuit.copy(modules = newCircuit.modules ++ addedModules),
annotations = state.annotations ++ addedAnnotations,
renames = Some(renames))
state.copy(circuit = newCircuit.copy(modules = newCircuit.modules ++ addedModules))
}
}

View File

@ -5,14 +5,15 @@ package midas.passes.fame
import java.io.{PrintWriter, File}
import firrtl._
import ir._
import Mappers._
import Utils._
import firrtl.ir._
import firrtl.Utils._
import firrtl.passes.MemPortUtils
import firrtl.analyses.InstanceGraph
import annotations.{InstanceTarget, Annotation, SingleTargetAnnotation}
import midas.targetutils.FirrtlFAMEModelAnnotation
import scala.annotation.tailrec
import scala.collection.mutable
import mutable.{LinkedHashSet, LinkedHashMap}
@ -20,15 +21,20 @@ class ExtractModel extends Transform {
def inputForm = HighForm
def outputForm = HighForm
def promoteModels(state: CircuitState): CircuitState = {
val anns = state.annotations.flatMap {
case a @ FirrtlFAMEModelAnnotation(it) if (it.module != it.circuit) => Seq(a, PromoteSubmoduleAnnotation(it))
case a => Seq(a)
}
if (anns.toSeq == state.annotations.toSeq) {
state
} else {
promoteModels((new PromoteSubmodule).runTransform(state.copy(annotations = anns)))
@tailrec
private def promoteModels(state: CircuitState): CircuitState = {
val fmAnns = state.annotations.collect({
case ann: FirrtlFAMEModelAnnotation => ann.target.module -> ann
}).toMap
// Pick by order of parent in linearization -- don't pick children of main
val modOrder = (new InstanceGraph(state.circuit)).moduleOrder.filterNot(_.name == state.circuit.main)
val topModelAnn = modOrder.collectFirst(Function.unlift(dm => fmAnns.get(dm.name)))
topModelAnn match {
case None => state
case Some(FirrtlFAMEModelAnnotation(it)) =>
val anns = PromoteSubmoduleAnnotation(it) +: state.annotations
val nextPromoted = (new PromoteSubmodule).runTransform(state.copy(annotations = anns))
promoteModels(nextPromoted)
}
}

View File

@ -307,9 +307,14 @@ class FAMETransform extends Transform {
renames
}
def staleTopPort(p: Port, analysis: FAMEChannelAnalysis): Boolean = p match {
case Port(_, name, _, ClockType) => name != WrapTop.hostClockName
case Port(_, name, _, _) => analysis.staleTopPorts.contains(analysis.topTarget.ref(name))
}
def transformTop(top: DefModule, analysis: FAMEChannelAnalysis): Module = top match {
case Module(info, name, ports, body) =>
val transformedPorts = ports.filterNot(p => analysis.staleTopPorts.contains(analysis.topTarget.ref(p.name))) ++
val transformedPorts = ports.filterNot(p => staleTopPort(p, analysis)) ++
analysis.transformedSinks.map(c => Port(NoInfo, s"${c}_sink", Input, analysis.getSinkHostDecoupledChannelType(c))) ++
analysis.transformedSources.map(c => Port(NoInfo, s"${c}_source", Output, analysis.getSourceHostDecoupledChannelType(c)))
val transformedStmts = Seq(body.map(updateNonChannelConnects(analysis))) ++

View File

@ -16,10 +16,18 @@ class LabelSRAMModels extends Transform {
val confwriter = new passes.memlib.ConfWriter("NONE")
val memutil = new passes.memlib.ReplaceMemMacros(confwriter)
// Wrapper gets converted to strip clocks from ports, has one top-level clock
def mem2Module(mem: DefMemory): Module = {
val ports = passes.MemPortUtils.memType(mem).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
val connects = ports.map(p => Connect(NoInfo, WSubField(WRef(mem.name), p.name), WRef(p.name)))
Module(mem.info, mem.name, ports, Block(mem +: connects))
val clockPort = Port(NoInfo, "clk", Input, ClockType)
def stripClocks(tpe: Type): Type = tpe match {
case BundleType(fields) => BundleType(fields.filterNot(_.tpe == ClockType))
}
val ports = mem.readers ++ mem.writers ++ mem.readwriters
val connects = ports.map(p => PartialConnect(NoInfo, WSubField(WRef(mem.name), p), WRef(p)))
val clkConnects = ports.map(p => Connect(NoInfo, WSubField(WSubField(WRef(mem.name), p), "clk"), WRef(clockPort)))
val modPorts = clockPort +: passes.MemPortUtils.memType(mem).fields.map(f => Port(NoInfo, f.name, Input, stripClocks(f.tpe)))
Module(mem.info, mem.name, modPorts, Block(mem +: connects ++: clkConnects))
}
override def execute(state: CircuitState): CircuitState = {
@ -46,6 +54,11 @@ class LabelSRAMModels extends Transform {
memModelAnnotations ++= mem.writers.map(rp => ModelWritePort(wrapperTarget.ref(rp)))
memModelAnnotations ++= mem.readwriters.map(rp => ModelReadWritePort(wrapperTarget.ref(rp)))
WDefInstance(mem.info, mem.name, wrapper.name, UnknownType)
case c: Connect if (Utils.kind(c.loc) == MemKind && c.loc.tpe == ClockType) =>
// change clock connects to target single mem wrapper clock
val (wr, e) = Utils.splitRef(c.loc)
val wrapperClock = Utils.mergeRef(wr, Utils.splitRef(e)._2)
if (annotatedMems.contains(mt.ref(wr.name))) c.copy(loc = wrapperClock) else c
case s => s
}
m.copy(body = m.body.map(onStmt))

View File

@ -70,11 +70,12 @@ class PromoteSubmodule extends Transform {
}
private def transformParentInstances(stmt: Statement, parentTemplate: WDefInstance, childTemplate: WDefInstance, namespace: Namespace, promotedNames: mutable.ArrayBuffer[String]): Statement = stmt match {
// TODO: this does not handle instances inside of whens
case oldParentInstance @ WDefInstance(_, _, parentTemplate.module, _) =>
val retypedParentInst = oldParentInstance.copy(tpe = parentTemplate.tpe)
val childPeerInst = childTemplate.copy(name = namespace.newName(oldParentInstance.name + "_" + childTemplate.name))
promotedNames += childPeerInst.name
val connection = Connect(childTemplate.info, instanceField(retypedParentInst, childTemplate.name), instanceRef(childPeerInst))
val connection = PartialConnect(childTemplate.info, instanceField(retypedParentInst, childTemplate.name), instanceRef(childPeerInst))
Block(Seq(retypedParentInst, childPeerInst, connection))
case Block(stmts) => Block(stmts map (s => transformParentInstances(s, parentTemplate, childTemplate, namespace, promotedNames)))
case s => s
@ -95,7 +96,7 @@ class PromoteSubmodule extends Transform {
val parentModule = updatedModules(parentInstances.head.module)
val originalTarget = CircuitTarget(state.circuit.main).module(parentModule.name).instOf(childInstance.name, childInstance.module)
if (parentModule.name == state.circuit.main) {
throw new PassException("Cannot promote child instance ${childInstance.name} from top module ${parentModule.name}")
throw new PassException(s"Cannot promote child instance ${childInstance.name} from top module ${parentModule.name}")
}
updatedModules(parentModule.name) = instanceToPort(parentModule, childInstance, childModule)
val grandparentInstances = parentInstances.flatMap(reversedIGraph.getEdges(_))

View File

@ -65,6 +65,10 @@ object Or extends BinaryBooleanOp {
val op = PrimOps.Or
}
object Xor extends BinaryBooleanOp {
val op = PrimOps.Xor
}
object Neq extends BinaryBooleanOp {
val op = PrimOps.Neq
}
@ -75,4 +79,12 @@ object Neq extends BinaryBooleanOp {
object RegZeroPreset {
def apply(info: Info, name: String, tpe: Type, clock: Expression): DefRegister =
DefRegister(info, name, tpe, clock, zero, WRef(name))
def apply(info: Info, name: String, clock: Expression): DefRegister =
DefRegister(info, name, BoolType, clock, zero, WRef(name))
}
object ConditionalConnect {
def apply(cond: Expression, lhs: Expression, rhs: Expression): Conditionally = {
Conditionally(NoInfo, cond, Connect(NoInfo, lhs, rhs), EmptyStmt)
}
}