Port multi-port serialized mem optimization to multi-clock
This commit is contained in:
parent
3143253dda
commit
2e635759f4
|
@ -6,11 +6,12 @@ import midas.widgets.{BridgeAnnotation, ClockBridgeAnnotation}
|
|||
import midas.passes.fame.{PromoteSubmodule, PromoteSubmoduleAnnotation, FAMEChannelConnectionAnnotation}
|
||||
|
||||
import firrtl._
|
||||
import firrtl.annotations._
|
||||
import firrtl.ir._
|
||||
import firrtl.passes.{InferTypes, ResolveKinds}
|
||||
import firrtl.traversals.Foreachers._
|
||||
import firrtl.transforms.TopWiring.{TopWiringAnnotation, TopWiringTransform, TopWiringOutputFilesAnnotation}
|
||||
import firrtl.passes.wiring.{Wiring, WiringInfo}
|
||||
import firrtl.annotations._
|
||||
import Utils._
|
||||
|
||||
import scala.collection.mutable
|
||||
|
@ -129,8 +130,13 @@ private[passes] class BridgeExtraction extends firrtl.Transform {
|
|||
// Annotate all bridge instances
|
||||
val instAnnoedState = annotateInstances(wrappedTopState)
|
||||
|
||||
def normalize(state: CircuitState): CircuitState = {
|
||||
val cx = RemoveTrivialPartialConnects.run(InferTypes.run(ResolveKinds.run(state.circuit)))
|
||||
state.copy(circuit = cx)
|
||||
}
|
||||
|
||||
// Promote all modules that are annotated as bridges
|
||||
val promotedState = promoteBridges(instAnnoedState)
|
||||
val promotedState = normalize(promoteBridges(instAnnoedState))
|
||||
|
||||
// Propogate bridge annotations to the IO created on the true top module
|
||||
val commutedState = commuteBridgeAnnotations(promotedState)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
// See LICENSE for license details.
|
||||
|
||||
package midas.passes
|
||||
|
||||
import firrtl._
|
||||
import firrtl.ir._
|
||||
import firrtl.passes._
|
||||
import firrtl.Utils._
|
||||
import firrtl.Mappers._
|
||||
import firrtl.options.{Dependency, PreservesAll}
|
||||
|
||||
object RemoveTrivialPartialConnects extends Pass with PreservesAll[Transform] {
|
||||
|
||||
override def prerequisites =
|
||||
Dependency(InferTypes) +: Dependency(ResolveKinds) +: stage.Forms.WorkingIR
|
||||
|
||||
private def onStmt(stmt: Statement): Statement = stmt match {
|
||||
case PartialConnect(i, l, e) if (l.tpe == e.tpe) => Connect(i, l, e)
|
||||
case s => s.map(onStmt)
|
||||
}
|
||||
|
||||
private def onModule(m: DefModule): DefModule = m.map(onStmt)
|
||||
|
||||
def run(c: Circuit): Circuit = {
|
||||
c.map(onModule)
|
||||
}
|
||||
}
|
|
@ -17,7 +17,6 @@ import firrtl.Mappers._
|
|||
import firrtl.passes.LowerTypes.loweredName
|
||||
import firrtl.Utils.{BoolType, splitRef, mergeRef, create_exps, flow, module_type}
|
||||
import firrtl.passes.wiring._
|
||||
import fame.{FAMEChannelConnectionAnnotation, FAMEChannelPortsAnnotation, FAMEChannelAnalysis, FAME1Transform}
|
||||
import Utils._
|
||||
import freechips.rocketchip.config.Parameters
|
||||
import freechips.rocketchip.diplomacy.LazyModule
|
||||
|
@ -117,7 +116,7 @@ private[passes] class SimulationMapping(targetName: String) extends firrtl.Trans
|
|||
Map(CircuitName(innerCircuit.main) -> Seq(CircuitName(outerCircuit.main))))
|
||||
|
||||
val innerAnnos = loweredInnerState.annotations.filter(_ match {
|
||||
case _: FAMEChannelConnectionAnnotation | _: FAMEChannelPortsAnnotation => false
|
||||
case _: midas.targetutils.FAMEAnnotation => false
|
||||
case _: BridgeIOAnnotation => false
|
||||
case _ => true
|
||||
})
|
||||
|
|
|
@ -185,15 +185,15 @@ class RAMModelInst(name: String, val readPorts: Seq[ReadPort], val writePorts: S
|
|||
) ++ readConnects ++ writeConnects
|
||||
}
|
||||
|
||||
def elaborateModel(parentCircuitNS: Namespace): (DefModule, Seq[Annotation]) = {
|
||||
def elaborateModel(parentCircuitNS: Namespace): Module = {
|
||||
val c3circuit = chisel3.Driver.elaborate(() =>
|
||||
new midas.models.sram.AsyncMemChiselModel(depth.toInt, dataWidth, readPorts.size, writePorts.size))
|
||||
val chirrtl = Parser.parse(chisel3.Driver.emit(c3circuit))
|
||||
val annos = PreLinkRenamingAnnotation(parentCircuitNS) +: c3circuit.annotations.map(_.toFirrtl)
|
||||
val state = new MiddleFirrtlCompiler().compile(CircuitState(chirrtl, ChirrtlForm, annos), Seq(PreLinkRenaming))
|
||||
require(state.circuit.modules.size == 1)
|
||||
val mod = state.circuit.modules.head
|
||||
(mod, state.annotations)
|
||||
val state = new MiddleFirrtlCompiler().compile(CircuitState(chirrtl, ChirrtlForm, Nil), Nil)
|
||||
require(state.circuit.modules.length == 1)
|
||||
state.circuit.modules.collectFirst({
|
||||
case m: Module => m.copy(name = name)
|
||||
}).get
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -225,21 +225,13 @@ class EmitAndWrapRAMModels extends Transform {
|
|||
val readWritePorts = annos.collect({ case anno: ModelReadWritePort => new ReadWritePort(anno, mod.ports)})
|
||||
require(readWritePorts.isEmpty)
|
||||
|
||||
val (hostClock, updatedPortList) = mod.ports.find(p => p.tpe == ClockType && p.name == "clock") match {
|
||||
case Some(port) => (port, mod.ports)
|
||||
case None =>
|
||||
val cPort = Port(NoInfo, "clock", Input, ClockType)
|
||||
(cPort, cPort +: mod.ports)
|
||||
}
|
||||
|
||||
val hostReset = mod.ports.find(_.name == "hostReset").get
|
||||
val hostClock = mod.ports.find(_.name == WrapTop.hostClockName).get
|
||||
val hostReset = mod.ports.find(_.name == WrapTop.hostResetName).get
|
||||
|
||||
val name = ns.newName("RamModel")
|
||||
val inst = new RAMModelInst(name, readPorts, writePorts)
|
||||
val (module, newAnnos) = inst.elaborateModel(Namespace(c))
|
||||
addedModules += module
|
||||
addedAnnotations ++= newAnnos
|
||||
Module(NoInfo, mod.name, updatedPortList, Block(inst.emitStatements(WRef(hostClock), WRef(hostReset))))
|
||||
addedModules += inst.elaborateModel(Namespace(c))
|
||||
Module(NoInfo, mod.name, mod.ports, Block(inst.emitStatements(WRef(hostClock), WRef(hostReset))))
|
||||
}
|
||||
|
||||
def onModule(mod: DefModule): DefModule = mod match {
|
||||
|
@ -249,11 +241,6 @@ class EmitAndWrapRAMModels extends Transform {
|
|||
|
||||
val newCircuit = state.circuit.map(onModule)
|
||||
|
||||
val renames = RenameMap.create(addedModules.map(m =>
|
||||
ModuleTarget(m.name, m.name) -> Seq(CircuitTarget(c.main))).toMap)
|
||||
|
||||
state.copy(circuit = newCircuit.copy(modules = newCircuit.modules ++ addedModules),
|
||||
annotations = state.annotations ++ addedAnnotations,
|
||||
renames = Some(renames))
|
||||
state.copy(circuit = newCircuit.copy(modules = newCircuit.modules ++ addedModules))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,14 +5,15 @@ package midas.passes.fame
|
|||
import java.io.{PrintWriter, File}
|
||||
|
||||
import firrtl._
|
||||
import ir._
|
||||
import Mappers._
|
||||
import Utils._
|
||||
import firrtl.ir._
|
||||
import firrtl.Utils._
|
||||
import firrtl.passes.MemPortUtils
|
||||
import firrtl.analyses.InstanceGraph
|
||||
import annotations.{InstanceTarget, Annotation, SingleTargetAnnotation}
|
||||
|
||||
import midas.targetutils.FirrtlFAMEModelAnnotation
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable
|
||||
import mutable.{LinkedHashSet, LinkedHashMap}
|
||||
|
||||
|
@ -20,15 +21,20 @@ class ExtractModel extends Transform {
|
|||
def inputForm = HighForm
|
||||
def outputForm = HighForm
|
||||
|
||||
def promoteModels(state: CircuitState): CircuitState = {
|
||||
val anns = state.annotations.flatMap {
|
||||
case a @ FirrtlFAMEModelAnnotation(it) if (it.module != it.circuit) => Seq(a, PromoteSubmoduleAnnotation(it))
|
||||
case a => Seq(a)
|
||||
}
|
||||
if (anns.toSeq == state.annotations.toSeq) {
|
||||
state
|
||||
} else {
|
||||
promoteModels((new PromoteSubmodule).runTransform(state.copy(annotations = anns)))
|
||||
@tailrec
|
||||
private def promoteModels(state: CircuitState): CircuitState = {
|
||||
val fmAnns = state.annotations.collect({
|
||||
case ann: FirrtlFAMEModelAnnotation => ann.target.module -> ann
|
||||
}).toMap
|
||||
// Pick by order of parent in linearization -- don't pick children of main
|
||||
val modOrder = (new InstanceGraph(state.circuit)).moduleOrder.filterNot(_.name == state.circuit.main)
|
||||
val topModelAnn = modOrder.collectFirst(Function.unlift(dm => fmAnns.get(dm.name)))
|
||||
topModelAnn match {
|
||||
case None => state
|
||||
case Some(FirrtlFAMEModelAnnotation(it)) =>
|
||||
val anns = PromoteSubmoduleAnnotation(it) +: state.annotations
|
||||
val nextPromoted = (new PromoteSubmodule).runTransform(state.copy(annotations = anns))
|
||||
promoteModels(nextPromoted)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -307,9 +307,14 @@ class FAMETransform extends Transform {
|
|||
renames
|
||||
}
|
||||
|
||||
def staleTopPort(p: Port, analysis: FAMEChannelAnalysis): Boolean = p match {
|
||||
case Port(_, name, _, ClockType) => name != WrapTop.hostClockName
|
||||
case Port(_, name, _, _) => analysis.staleTopPorts.contains(analysis.topTarget.ref(name))
|
||||
}
|
||||
|
||||
def transformTop(top: DefModule, analysis: FAMEChannelAnalysis): Module = top match {
|
||||
case Module(info, name, ports, body) =>
|
||||
val transformedPorts = ports.filterNot(p => analysis.staleTopPorts.contains(analysis.topTarget.ref(p.name))) ++
|
||||
val transformedPorts = ports.filterNot(p => staleTopPort(p, analysis)) ++
|
||||
analysis.transformedSinks.map(c => Port(NoInfo, s"${c}_sink", Input, analysis.getSinkHostDecoupledChannelType(c))) ++
|
||||
analysis.transformedSources.map(c => Port(NoInfo, s"${c}_source", Output, analysis.getSourceHostDecoupledChannelType(c)))
|
||||
val transformedStmts = Seq(body.map(updateNonChannelConnects(analysis))) ++
|
||||
|
|
|
@ -16,10 +16,18 @@ class LabelSRAMModels extends Transform {
|
|||
val confwriter = new passes.memlib.ConfWriter("NONE")
|
||||
val memutil = new passes.memlib.ReplaceMemMacros(confwriter)
|
||||
|
||||
|
||||
// Wrapper gets converted to strip clocks from ports, has one top-level clock
|
||||
def mem2Module(mem: DefMemory): Module = {
|
||||
val ports = passes.MemPortUtils.memType(mem).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
|
||||
val connects = ports.map(p => Connect(NoInfo, WSubField(WRef(mem.name), p.name), WRef(p.name)))
|
||||
Module(mem.info, mem.name, ports, Block(mem +: connects))
|
||||
val clockPort = Port(NoInfo, "clk", Input, ClockType)
|
||||
def stripClocks(tpe: Type): Type = tpe match {
|
||||
case BundleType(fields) => BundleType(fields.filterNot(_.tpe == ClockType))
|
||||
}
|
||||
val ports = mem.readers ++ mem.writers ++ mem.readwriters
|
||||
val connects = ports.map(p => PartialConnect(NoInfo, WSubField(WRef(mem.name), p), WRef(p)))
|
||||
val clkConnects = ports.map(p => Connect(NoInfo, WSubField(WSubField(WRef(mem.name), p), "clk"), WRef(clockPort)))
|
||||
val modPorts = clockPort +: passes.MemPortUtils.memType(mem).fields.map(f => Port(NoInfo, f.name, Input, stripClocks(f.tpe)))
|
||||
Module(mem.info, mem.name, modPorts, Block(mem +: connects ++: clkConnects))
|
||||
}
|
||||
|
||||
override def execute(state: CircuitState): CircuitState = {
|
||||
|
@ -46,6 +54,11 @@ class LabelSRAMModels extends Transform {
|
|||
memModelAnnotations ++= mem.writers.map(rp => ModelWritePort(wrapperTarget.ref(rp)))
|
||||
memModelAnnotations ++= mem.readwriters.map(rp => ModelReadWritePort(wrapperTarget.ref(rp)))
|
||||
WDefInstance(mem.info, mem.name, wrapper.name, UnknownType)
|
||||
case c: Connect if (Utils.kind(c.loc) == MemKind && c.loc.tpe == ClockType) =>
|
||||
// change clock connects to target single mem wrapper clock
|
||||
val (wr, e) = Utils.splitRef(c.loc)
|
||||
val wrapperClock = Utils.mergeRef(wr, Utils.splitRef(e)._2)
|
||||
if (annotatedMems.contains(mt.ref(wr.name))) c.copy(loc = wrapperClock) else c
|
||||
case s => s
|
||||
}
|
||||
m.copy(body = m.body.map(onStmt))
|
||||
|
|
|
@ -70,11 +70,12 @@ class PromoteSubmodule extends Transform {
|
|||
}
|
||||
|
||||
private def transformParentInstances(stmt: Statement, parentTemplate: WDefInstance, childTemplate: WDefInstance, namespace: Namespace, promotedNames: mutable.ArrayBuffer[String]): Statement = stmt match {
|
||||
// TODO: this does not handle instances inside of whens
|
||||
case oldParentInstance @ WDefInstance(_, _, parentTemplate.module, _) =>
|
||||
val retypedParentInst = oldParentInstance.copy(tpe = parentTemplate.tpe)
|
||||
val childPeerInst = childTemplate.copy(name = namespace.newName(oldParentInstance.name + "_" + childTemplate.name))
|
||||
promotedNames += childPeerInst.name
|
||||
val connection = Connect(childTemplate.info, instanceField(retypedParentInst, childTemplate.name), instanceRef(childPeerInst))
|
||||
val connection = PartialConnect(childTemplate.info, instanceField(retypedParentInst, childTemplate.name), instanceRef(childPeerInst))
|
||||
Block(Seq(retypedParentInst, childPeerInst, connection))
|
||||
case Block(stmts) => Block(stmts map (s => transformParentInstances(s, parentTemplate, childTemplate, namespace, promotedNames)))
|
||||
case s => s
|
||||
|
@ -95,7 +96,7 @@ class PromoteSubmodule extends Transform {
|
|||
val parentModule = updatedModules(parentInstances.head.module)
|
||||
val originalTarget = CircuitTarget(state.circuit.main).module(parentModule.name).instOf(childInstance.name, childInstance.module)
|
||||
if (parentModule.name == state.circuit.main) {
|
||||
throw new PassException("Cannot promote child instance ${childInstance.name} from top module ${parentModule.name}")
|
||||
throw new PassException(s"Cannot promote child instance ${childInstance.name} from top module ${parentModule.name}")
|
||||
}
|
||||
updatedModules(parentModule.name) = instanceToPort(parentModule, childInstance, childModule)
|
||||
val grandparentInstances = parentInstances.flatMap(reversedIGraph.getEdges(_))
|
||||
|
|
|
@ -65,6 +65,10 @@ object Or extends BinaryBooleanOp {
|
|||
val op = PrimOps.Or
|
||||
}
|
||||
|
||||
object Xor extends BinaryBooleanOp {
|
||||
val op = PrimOps.Xor
|
||||
}
|
||||
|
||||
object Neq extends BinaryBooleanOp {
|
||||
val op = PrimOps.Neq
|
||||
}
|
||||
|
@ -75,4 +79,12 @@ object Neq extends BinaryBooleanOp {
|
|||
object RegZeroPreset {
|
||||
def apply(info: Info, name: String, tpe: Type, clock: Expression): DefRegister =
|
||||
DefRegister(info, name, tpe, clock, zero, WRef(name))
|
||||
def apply(info: Info, name: String, clock: Expression): DefRegister =
|
||||
DefRegister(info, name, BoolType, clock, zero, WRef(name))
|
||||
}
|
||||
|
||||
object ConditionalConnect {
|
||||
def apply(cond: Expression, lhs: Expression, rhs: Expression): Conditionally = {
|
||||
Conditionally(NoInfo, cond, Connect(NoInfo, lhs, rhs), EmptyStmt)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue