refactor: remove replace_subindex

This commit is contained in:
raybdzhou 2022-06-10 19:33:59 +08:00
parent 8b1b849df5
commit db3333abda
5 changed files with 66 additions and 83 deletions

View File

@ -17,8 +17,8 @@ class FullAdder(Module):
if __name__ == '__main__':
# emit high firrtl
Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
# Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
# emit lowered firrtl
Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
# Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
# emit verilog
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")

View File

@ -4,12 +4,12 @@ from dataclasses import dataclass
from pyhcl.ir import low_ir
from pyhcl.dsl.check_and_infer import CheckAndInfer
from pyhcl.passes.replace_subaccess import ReplaceSubaccess
from pyhcl.passes.replace_subindex import ReplaceSubindex
from pyhcl.passes.expand_aggregate import ExpandAggregate
from pyhcl.passes.expand_whens import ExpandWhens
from pyhcl.passes.expand_memory import ExpandMemory
from pyhcl.passes.optimize import Optimize
from pyhcl.passes.verilog_optimize import VerilogOptimize
from pyhcl.passes.remove_access import RemoveAccess
from pyhcl.passes.utils import AutoName
class Form(ABC):
@ -41,7 +41,7 @@ class LowForm(Form):
self.c = ReplaceSubaccess().run(self.c)
self.c = ExpandAggregate().run(self.c)
self.c = ExpandWhens().run(self.c)
self.c = ReplaceSubindex().run(self.c)
self.c = RemoveAccess().run(self.c)
self.c = Optimize().run(self.c)
return self.c.serialize()
@ -53,7 +53,7 @@ class Verilog(Form):
self.c = CheckAndInfer.run(self.c)
self.c = ExpandAggregate().run(self.c)
self.c = ReplaceSubaccess().run(self.c)
self.c = ReplaceSubindex().run(self.c)
self.c = RemoveAccess().run(self.c)
self.c = VerilogOptimize().run(self.c)
self.c = Optimize().run(self.c)
return self.c.verilog_serialize()

View File

@ -0,0 +1,60 @@
from pyhcl.ir.low_ir import *
from typing import List
from dataclasses import dataclass
from pyhcl.passes._pass import Pass
@dataclass
class RemoveAccess(Pass):
def run(self, c: Circuit) -> Circuit:
modules: List[Module] = []
def remove_access(e: Expression, type: Type = None, name: str = None) -> Expression:
if isinstance(e, Reference):
return Reference(f"{e.name}{name}", type)
elif isinstance(e, SubIndex):
return remove_access(e.expr, type, f"_{e.value}")
elif isinstance(e, SubField):
return remove_access(e.expr, type, f"_{e.name}")
else:
return e
def remove_access_e(e: Expression) -> Expression:
if isinstance(e, (SubIndex, SubField)):
return remove_access(e, e.typ)
elif isinstance(e, ValidIf):
return ValidIf(remove_access_e(e.cond), remove_access_e(e.value), e.typ)
elif isinstance(e, Mux):
return Mux(remove_access_e(e.cond), remove_access_e(e.tval), remove_access_e(e.fval), e.typ)
elif isinstance(e, DoPrim):
return DoPrim(e.op, [remove_access_e(arg) for arg in e.args], e.consts, e.typ)
else:
return e
def remove_access_s(s: Statement) -> Statement:
if isinstance(s, Block):
stmts: List[Statement] = []
for sx in s.stmts:
stmts.append(remove_access_s(sx))
return Block(stmts)
elif isinstance(s, Conditionally):
return Conditionally(remove_access_e(s.pred), remove_access_s(s.conseq), remove_access_s(s.alt), s.info)
elif isinstance(s, DefRegister):
return DefRegister(s.name, s.typ, remove_access_e(s.clock), remove_access_e(s.reset), remove_access_e(s.init), s.info)
elif isinstance(s, DefNode):
return DefNode(s.name, remove_access_e(s.value), s.info)
elif isinstance(s, DefMemPort):
return DefMemPort(s.name, s.mem, remove_access_e(s.index), remove_access_e(s.clk), s.rw, s.info)
elif isinstance(s, Connect):
return Connect(remove_access_e(s.loc), remove_access_e(s.expr), s.info)
else:
return s
def remove_access_m(m: DefModule) -> DefModule:
if isinstance(m, Module):
return Module(m.name, m.ports, remove_access_s(m.body), m.typ, m.info)
else:
return m
for m in c.modules:
modules.append(remove_access_m(m))
return Circuit(modules, c.main, c.info)

View File

@ -1,77 +0,0 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
@dataclass
class ReplaceSubindex(Pass):
def run(self, c: Circuit) -> Circuit:
modules: List[DefModule] = []
keep_subfield: List[str] = []
def get_name(e: Expression) -> str:
if isinstance(e, (SubField, SubIndex, SubAccess)):
return get_name(e.expr)
else:
return e.name
def replace_subindex_e(e: Expression) -> Expression:
if isinstance(e, SubIndex):
return Reference(e.verilog_serialize(), e.typ)
if isinstance(e, SubField):
if get_name(e) in keep_subfield:
return e
return Reference(e.verilog_serialize(), e.typ)
if isinstance(e, SubAccess):
return SubAccess(replace_subindex_e(e.expr), replace_subindex_e(e.index), e.typ)
if isinstance(e, ValidIf):
return ValidIf(e.cond, replace_subindex_e(e.value), e.typ)
if isinstance(e, Mux):
return Mux(e.cond, replace_subindex_e(e.tval), replace_subindex_e(e.fval), e.typ)
if isinstance(e, DoPrim):
return DoPrim(e.op, list(map(lambda ex: replace_subindex_e(ex), e.args)), e.consts, e.typ)
return e
def replace_subindex(stmt: Statement) -> Statement:
if isinstance(stmt, Connect):
return Connect(replace_subindex_e(stmt.loc), replace_subindex_e(stmt.expr),
stmt.info, stmt.blocking, stmt.bidirection, stmt.mem)
elif isinstance(stmt, DefNode):
return DefNode(stmt.name, replace_subindex_e(stmt.value), stmt.info)
elif isinstance(stmt, DefRegister):
return DefRegister(stmt.name, stmt.typ, replace_subindex_e(stmt.clock),
replace_subindex_e(stmt.reset), replace_subindex_e(stmt.init), stmt.info)
elif isinstance(stmt, DefMemPort):
keep_subfield.append(stmt.mem.name)
return DefMemPort(stmt.name, stmt.mem, replace_subindex_e(stmt.index),
replace_subindex_e(stmt.clk), stmt.rw, stmt.info)
elif isinstance(stmt, Conditionally):
return Conditionally(replace_subindex_e(stmt.pred), replace_subindex(stmt.conseq),
replace_subindex(stmt.alt), stmt.info)
elif isinstance(stmt, Block):
return Block(replace_subindex_s(stmt.stmts))
elif isinstance(stmt, DefInstance):
keep_subfield.append(stmt.name)
return stmt
else:
return stmt
def replace_subindex_s(stmts: List[Statement]) -> List[Statement]:
new_stmts = []
for stmt in stmts:
new_stmts.append(replace_subindex(stmt))
return new_stmts
def replace_subindex_m(m: DefModule) -> DefModule:
if isinstance(m, ExtModule):
return m
if not hasattr(m, 'body') or not isinstance(m.body, Block):
return m
if not hasattr(m.body, 'stmts') or not isinstance(m.body.stmts, list):
return m
return Module(m.name, m.ports, Block(replace_subindex_s(m.body.stmts)), m.typ, m.info)
for m in c.modules:
modules.append(replace_subindex_m(m))
return Circuit(modules, c.main, c.info)

View File

@ -68,7 +68,7 @@ class VerilogOptimize(Pass):
return stmt
def verilog_optimize_m(m: Module) -> Module:
def verilog_optimize_m(m: DefModule) -> DefModule:
node_map: Dict[str, DefNode] = {}
filter_nodes: set = set()
if isinstance(m, Module):