forked from opendacs/PyHCL
122 lines
6.5 KiB
Python
122 lines
6.5 KiB
Python
from pyhcl.ir.low_ir import *
|
|
from pyhcl.ir.low_prim import *
|
|
from typing import List, Dict
|
|
from dataclasses import dataclass
|
|
from pyhcl.passes._pass import Pass
|
|
|
|
@dataclass
class ExpandSequential(Pass):
    """Pass that separates register and memory writes from combinational logic.

    For every ``Module`` in the circuit, connects whose target is a
    ``DefRegister`` are pulled out of ``Conditionally`` statements and
    collected, per clock, into ``AlwaysBlock`` statements appended to the end
    of the module body.  Register reset assignments and memory write ports
    are added to the same per-clock blocks.  Clocks are identified by the
    string returned from ``verilog_serialize()``.
    """

    def run(self, c: Circuit):
        """Return a new ``Circuit`` whose modules have been expanded.

        Args:
            c: The circuit to transform.  It is not modified; modules that
                are not ``Module`` instances are passed through unchanged.

        Returns:
            A ``Circuit`` built from the transformed modules with the same
            ``main`` and ``info`` as ``c``.
        """

        def get_ref_name(e: Expression):
            """Return the name of the ``Reference`` at the root of ``e``.

            ``SubAccess``/``SubField``/``SubIndex`` wrappers are peeled off
            first.  Returns ``None`` when the root is not a ``Reference``.
            """
            while isinstance(e, (SubAccess, SubField, SubIndex)):
                e = e.expr
            if isinstance(e, Reference):
                return e.name
            return None

        def negate(pred: Expression) -> Expression:
            """Build the ``Not`` primitive applied to ``pred``."""
            return DoPrim(Not(), [pred], [], pred.typ)

        def expand_sequential_s(s: Statement, stmts: List[Statement],
                                reg_map: Dict[str, DefRegister],
                                clock_map: Dict[str, Expression]):
            """Split ``s`` into its sequential and combinational parts.

            Args:
                s: Statement to split (a ``Conditionally``, a ``Block`` or
                    anything else).
                stmts: Output list; statements found inside a block that are
                    neither ``Connect`` nor ``Conditionally`` are appended
                    here (i.e. hoisted out of the conditional).
                reg_map: Registers of the current module, by name.
                clock_map: Output mapping from clock key to clock expression;
                    updated for every register clock encountered.

            Returns:
                A pair ``(seq_map, com)``: ``seq_map`` maps a clock key to
                the statement holding the register connects for that clock,
                ``com`` is the remaining statement (``EmptyStmt`` if nothing
                remains).
            """
            if isinstance(s, Conditionally):
                conseq_seq_map, conseq_com = expand_sequential_s(s.conseq, stmts, reg_map, clock_map)
                alt_seq_map, alt_com = expand_sequential_s(s.alt, stmts, reg_map, clock_map)

                seq_map: Dict[str, Statement] = {}
                for k, conseq_seq in conseq_seq_map.items():
                    seq_map[k] = Conditionally(s.pred, conseq_seq, alt_seq_map.get(k, EmptyStmt()), s.info)
                # Clocks that are only written in the alternative branch must
                # not be dropped; guard them with the inverted predicate so
                # the consequent is never empty.
                for k, alt_seq in alt_seq_map.items():
                    if k not in conseq_seq_map:
                        seq_map[k] = Conditionally(negate(s.pred), alt_seq, EmptyStmt(), s.info)

                conseq_empty = isinstance(conseq_com, EmptyStmt)
                alt_empty = isinstance(alt_com, EmptyStmt)
                if conseq_empty and alt_empty:
                    com = EmptyStmt()
                elif conseq_empty:
                    # Only the alternative has combinational content: emit it
                    # as the consequent of the inverted predicate.
                    com = Conditionally(negate(s.pred), alt_com, EmptyStmt(), s.info)
                else:
                    com = Conditionally(s.pred, conseq_com, alt_com, s.info)
                return seq_map, com

            if isinstance(s, Block):
                com_stmts: List[Statement] = []
                seq_stmts_map: Dict[str, List[Statement]] = {}
                for sx in s.stmts:
                    if isinstance(sx, Connect):
                        reg = reg_map.get(get_ref_name(sx.loc))
                        if reg is None:
                            com_stmts.append(sx)
                            continue
                        clock_key = reg.clock.verilog_serialize()
                        clock_map.setdefault(clock_key, reg.clock)
                        # The connect is rebuilt with its fourth argument set
                        # to False (presumably the "blocking" flag -- confirm
                        # against the Connect definition).
                        seq_stmts_map.setdefault(clock_key, []).append(
                            Connect(sx.loc, sx.expr, sx.info, False, sx.bidirection, sx.mem))
                    elif isinstance(sx, Conditionally):
                        seq_when_map, com_when = expand_sequential_s(sx, stmts, reg_map, clock_map)
                        for k, seq_when in seq_when_map.items():
                            seq_stmts_map.setdefault(k, []).append(seq_when)
                        # Skip empty remnants so that a branch holding only
                        # register writes is reported as EmptyStmt.
                        if not isinstance(com_when, EmptyStmt):
                            com_stmts.append(com_when)
                    else:
                        stmts.append(sx)
                return ({k: Block(v) for k, v in seq_stmts_map.items()},
                        Block(com_stmts) if com_stmts else EmptyStmt())

            return {}, s

        def expand_sequential(stmts: List[Statement]) -> List[Statement]:
            """Return the module body ``stmts`` with always blocks appended."""
            reg_map: Dict[str, DefRegister] = {sx.name: sx for sx in stmts if isinstance(sx, DefRegister)}
            mem_map: Dict[str, DefMemPort] = {sx.name: sx for sx in stmts if isinstance(sx, DefMemPort)}
            # Per-module state: sharing these between modules would copy the
            # always-block contents of one module into the next.
            block_map: Dict[str, List[Statement]] = {}
            clock_map: Dict[str, Expression] = {}
            new_stmts: List[Statement] = []

            for stmt in stmts:
                if not isinstance(stmt, Conditionally):
                    new_stmts.append(stmt)
                    continue
                seq_map, com = expand_sequential_s(stmt, new_stmts, reg_map, clock_map)
                if not isinstance(com, EmptyStmt):
                    new_stmts.append(com)
                for k, seq in seq_map.items():
                    block = block_map.setdefault(k, [])
                    if not isinstance(seq, EmptyStmt):
                        block.append(seq)

            # Reset assignments, grouped by clock and then by reset signal so
            # that each register is reset by its own reset expression.
            reset_map: Dict[str, Dict[str, List[Statement]]] = {}
            reset_sign_map: Dict[str, Dict[str, Expression]] = {}
            for reg in reg_map.values():
                # NOTE(review): a register with an init value but no reset
                # signal has nothing to be conditioned on and is skipped.
                if reg.init is None or reg.reset is None:
                    continue
                clock_key = reg.clock.verilog_serialize()
                reset_key = reg.reset.verilog_serialize()
                clock_map.setdefault(clock_key, reg.clock)
                reset_sign_map.setdefault(clock_key, {}).setdefault(reset_key, reg.reset)
                reset_map.setdefault(clock_key, {}).setdefault(reset_key, []).append(
                    Connect(Reference(reg.name, reg.typ), reg.init, reg.info, False))

            # Appended after the regular register writes of the same clock.
            for clock_key, connects_by_reset in reset_map.items():
                block = block_map.setdefault(clock_key, [])
                for reset_key, connects in connects_by_reset.items():
                    block.append(Conditionally(reset_sign_map[clock_key][reset_key], Block(connects), EmptyStmt()))

            # Memory ports whose ``rw`` flag is False get a guarded write
            # ``mem[addr] <= data`` enabled by ``en & mask``.
            for port_name, port in mem_map.items():
                if port.rw is not False:
                    continue
                mem: Reference = port.mem
                sig = DoPrim(And(), [Reference(f"{mem.name}_{port_name}_en", UIntType(IntWidth(1))),
                                     Reference(f"{mem.name}_{port_name}_mask", UIntType(IntWidth(1)))],
                             [], UIntType(IntWidth(1)))
                con = Connect(SubAccess(mem, Reference(f"{mem.name}_{port_name}_addr", port.index.typ), mem.typ),
                              Reference(f"{mem.name}_{port_name}_data", mem.typ), port.info, False)
                clock_key = port.clk.verilog_serialize()
                clock_map.setdefault(clock_key, port.clk)
                block_map.setdefault(clock_key, []).append(Conditionally(sig, Block([con]), EmptyStmt()))

            new_stmts.extend(AlwaysBlock(block_map[k], clock_map[k]) for k in block_map)
            return new_stmts

        def expand_sequential_m(m: DefModule) -> DefModule:
            """Expand the body of a ``Module``; return other modules as-is."""
            if isinstance(m, Module):
                return Module(m.name, m.ports, Block(expand_sequential(m.body.stmts)), m.typ)
            return m

        return Circuit([expand_sequential_m(m) for m in c.modules], c.main, c.info)