forked from opendacs/PyHCL
feat: add SingleClockStepper
This commit is contained in:
parent
9858c6b4dd
commit
2e34e0a979
|
@ -586,10 +586,12 @@ class DefRegister(Statement):
|
|||
info: Info = NoInfo()
|
||||
|
||||
def serialize(self) -> str:
|
||||
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
|
||||
if self.init is not None else ""
|
||||
# i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
|
||||
# if self.reset is not None else ""
|
||||
if self.init:
|
||||
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
|
||||
if self.init is not None else ""
|
||||
else:
|
||||
i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
|
||||
if self.reset is not None else ""
|
||||
return f'reg {self.name} : {self.typ.serialize()}, {self.clock.serialize()}{i}{self.info.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
|
@ -925,7 +927,7 @@ class Block(Statement):
|
|||
new_body = self.merge_node_s(self, {}, set())
|
||||
manager = PassManager(new_body)
|
||||
new_blocks = manager.renew()
|
||||
CheckCombLoop.run(new_blocks)
|
||||
# CheckCombLoop.run(new_blocks)
|
||||
always_blocks = manager.gen_all_always_block()
|
||||
|
||||
return '\n'.join([stmt.verilog_serialize() for stmt in new_blocks.stmts]) + f'\n{always_blocks}' if new_blocks.stmts else ""
|
||||
|
|
|
@ -15,8 +15,6 @@ class ExpandWhens(Pass):
|
|||
def flatten(s: Statement) -> List[Statement]:
|
||||
new_stmts = []
|
||||
conseq, alt = s.conseq, s.alt
|
||||
if isinstance(conseq, EmptyStmt) or isinstance(alt, EmptyStmt):
|
||||
...
|
||||
if isinstance(conseq, Conditionally):
|
||||
new_stmts += flatten(conseq)
|
||||
if isinstance(alt, Conditionally):
|
||||
|
@ -38,13 +36,20 @@ class ExpandWhens(Pass):
|
|||
def expand_whens(stmt: Statement, stmts: List[Statement], reference: Dict[str, Expression]):
|
||||
if isinstance(stmt, Conditionally):
|
||||
flat_cond = flatten(stmt)
|
||||
has_gen = {}
|
||||
for pred, sx in flat_cond:
|
||||
if isinstance(sx, Connect):
|
||||
name = auto_gen_name()
|
||||
loc_name = sx.loc.verilog_serialize() if isinstance(sx.loc, SubIndex) else sx.loc.serialize()
|
||||
loc = sx.loc if loc_name not in reference else reference[loc_name]
|
||||
stmts.append(DefNode(name, Mux(pred, sx.expr, loc, sx.expr.typ)))
|
||||
reference[loc_name] = Reference(name, sx.loc.typ)
|
||||
if pred.serialize() in has_gen:
|
||||
s = stmts.pop()
|
||||
stmts.append(DefNode(s.name, Mux(pred, has_gen[pred.serialize()], sx.expr, sx.expr.typ)))
|
||||
has_gen[pred.serialize()] = sx.expr
|
||||
else:
|
||||
name = auto_gen_name()
|
||||
loc_name = sx.loc.verilog_serialize() if isinstance(sx.loc, SubIndex) else sx.loc.serialize()
|
||||
loc = sx.loc if loc_name not in reference else reference[loc_name][1]
|
||||
stmts.append(DefNode(name, Mux(pred, sx.expr, loc, sx.expr.typ)))
|
||||
reference[loc_name] = (sx.loc, Reference(name, sx.loc.typ))
|
||||
has_gen[pred.serialize()] = sx.expr
|
||||
else:
|
||||
stmts.append(sx)
|
||||
else:
|
||||
|
@ -56,7 +61,7 @@ class ExpandWhens(Pass):
|
|||
for stmt in stmts:
|
||||
expand_whens(stmt, new_stmts, reference)
|
||||
for ref in reference:
|
||||
new_stmts.append(Connect(Reference(ref, reference[ref].typ), reference[ref]))
|
||||
new_stmts.append(Connect(reference[ref][0], reference[ref][1]))
|
||||
return new_stmts
|
||||
|
||||
def expand_whens_m(m: DefModule) -> DefModule:
|
||||
|
|
|
@ -33,8 +33,8 @@ class InferTypes(Pass):
|
|||
if isinstance(s, DefRegister):
|
||||
typs[s.name] = s.typ
|
||||
clock = infer_types_e(typs, s.clock) if hasattr(s, 'clock') and isinstance(s.clock, Expression) else None
|
||||
reset = UIntLiteral(0, IntWidth(1))
|
||||
init = Reference(s.name, s.typ)
|
||||
reset = s.reset if hasattr(s, 'reset') and isinstance(s.init, Expression) else Reference('reset', ResetType())
|
||||
init = s.init if hasattr(s, 'init') and isinstance(s.init, Expression) else Reference(s.name, s.typ)
|
||||
return DefRegister(s.name, s.typ, clock, reset, init, s.info)
|
||||
elif isinstance(s, DefWire):
|
||||
typs[s.name] = s.typ
|
||||
|
|
|
@ -78,8 +78,13 @@ class ReplaceSubaccess(Pass):
|
|||
stmts.append(node)
|
||||
nodes.append(node)
|
||||
else:
|
||||
node = DefNode(auto_gen_name(), Mux(exprs[i], index[i],
|
||||
Reference(nodes[-1].name, nodes[-1].value.typ), get_groud_type(get_type(target_e))))
|
||||
if isinstance(nodes[-1].value, ValidIf):
|
||||
node = DefNode(auto_gen_name(), Mux(exprs[i], index[i],
|
||||
nodes[-1].value.value, get_groud_type(get_type(target_e))))
|
||||
stmts.pop()
|
||||
else:
|
||||
node = DefNode(auto_gen_name(), Mux(exprs[i], index[i],
|
||||
Reference(nodes[-1].name, nodes[-1].value.typ), get_groud_type(get_type(target_e))))
|
||||
stmts.append(node)
|
||||
nodes.append(node)
|
||||
return Reference(nodes[-1].name, nodes[-1].value.typ)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from curses import endwin
|
||||
import math
|
||||
|
||||
from typing import List
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
from abc import ABC, abstractclassmethod
|
||||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.tester.symbol_table import SymbolTable
|
||||
from pyhcl.tester.executer import TesterExecuter
|
||||
|
||||
class ClockStepper(ABC):
|
||||
@abstractclassmethod
|
||||
def bump_clock(self):
|
||||
...
|
||||
|
||||
@abstractclassmethod
|
||||
def run(self):
|
||||
...
|
||||
|
||||
@abstractclassmethod
|
||||
def get_cycle_count(self):
|
||||
...
|
||||
|
||||
@abstractclassmethod
|
||||
def combinational_bump(self):
|
||||
...
|
||||
|
||||
class SingleClockStepper(ClockStepper):
|
||||
def __init__(self, mname: str, symbol: str, executor: TesterExecuter, table: SymbolTable):
|
||||
self.mname: str = mname
|
||||
self.clock_symbol: str = symbol
|
||||
self.executor: TesterExecuter = executor
|
||||
self.table: SymbolTable = table
|
||||
self.clock_cycles = 0
|
||||
self.combinational_bumps = 0
|
||||
|
||||
def handle_name(self, name):
|
||||
names = name.split(".")
|
||||
names.reverse()
|
||||
return names
|
||||
|
||||
def bump_clock(self, mname: str, clock_symbol: str, value: int):
|
||||
self.table.set_symbol_value(mname, self.handle_name(clock_symbol), value)
|
||||
self.clock_cycles += 1
|
||||
|
||||
def combinational_bump(self, value: int):
|
||||
self.combinational_bumps += value
|
||||
|
||||
def get_cycle_count(self):
|
||||
return self.clock_cycles
|
||||
|
||||
def run(self, steps: int):
|
||||
def raise_clock():
|
||||
self.table.set_symbol_value(self.mname, self.handle_name(self.clock_symbol), 1)
|
||||
self.executor.execute(self.mname)
|
||||
|
||||
self.combinational_bumps = 0
|
||||
|
||||
def lower_clock():
|
||||
self.table.set_symbol_value(self.mname, self.handle_name(self.clock_symbol), 0)
|
||||
self.combinational_bumps = 0
|
||||
|
||||
for _ in range(steps):
|
||||
if self.executor.get_inputchange():
|
||||
self.executor.execute(self.mname)
|
||||
self.clock_cycles += 1
|
||||
|
||||
raise_clock()
|
||||
lower_clock()
|
||||
|
||||
|
||||
class MultiClockStepper(ClockStepper):
|
||||
# TODO: Add MultiCLockStepper
|
||||
...
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -73,6 +73,7 @@ class TesterCompiler:
|
|||
lambda table=None: get_func(table),
|
||||
lambda s, table=None: set_func(s, table))
|
||||
elif isinstance(expr, SubIndex):
|
||||
# size = expr.expr.typ.size-1
|
||||
names.append(expr.value)
|
||||
e = self.gen_working_ir(mname, names, expr.expr)
|
||||
get_func, set_func = e.get_func, e.set_func
|
||||
|
@ -136,7 +137,14 @@ class TesterCompiler:
|
|||
return int(value, 2)
|
||||
return lambda table=None: bits(args, consts, table)
|
||||
elif isinstance(op, Cat):
|
||||
return lambda table=None: int(bin(args[0].get_value(table))[2:]+bin(args[1].get_value(table))[2:], 2)
|
||||
def cat(args, table):
|
||||
hi = args[0].get_value(table) if isinstance(args[0].get_value(table), str) else bin(args[0].get_value(table))[2:]
|
||||
lo = args[1].get_value(table) if isinstance(args[1].get_value(table), str) else bin(args[1].get_value(table))[2:]
|
||||
# max_len = len(hi) if len(hi) >= len(lo) else len(lo)
|
||||
# hi = '{:032b}'.format(args[0].get_value(table))[-max_len:]
|
||||
# lo = '{:032b}'.format(args[1].get_value(table))[-max_len:]
|
||||
return hi+lo
|
||||
return lambda table=None: cat(args, table)
|
||||
elif isinstance(op, (AsUInt, AsSInt)):
|
||||
return lambda table=None: int(args[0].get_value(table))
|
||||
elif isinstance(op, AsClock):
|
||||
|
|
|
@ -5,17 +5,24 @@ from pyhcl.ir.low_ir import *
|
|||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.tester.compiler import TesterCompiler
|
||||
from pyhcl.tester.symbol_table import SymbolTable
|
||||
from pyhcl.tester.clock_stepper import SingleClockStepper
|
||||
|
||||
class TesterExecuter:
|
||||
def __init__(self, circuit: Circuit):
|
||||
self.circuit = circuit
|
||||
self.symbol_table = SymbolTable()
|
||||
self.clock_table = {}
|
||||
self.inputchange = False
|
||||
|
||||
|
||||
def handle_name(self, name):
|
||||
names = name.split(".")
|
||||
names.reverse()
|
||||
return names
|
||||
|
||||
def get_inputchange(self):
|
||||
return self.inputchange
|
||||
|
||||
def get_ref_name(self, e: Expression):
|
||||
if isinstance(e, SubField):
|
||||
return self.get_ref_name(e.expr)
|
||||
|
@ -31,11 +38,6 @@ class TesterExecuter:
|
|||
stmt.loc.set_value(stmt.expr.get_value(table), table)
|
||||
elif isinstance(stmt, DefNode):
|
||||
self.symbol_table.set_symbol_value(m.name, self.handle_name(stmt.name), stmt.value.get_value(table), table)
|
||||
elif isinstance(stmt, Conditionally):
|
||||
if stmt.pred.get_value(table) > 0:
|
||||
self.execute_stmt(m, stmt.conseq, table)
|
||||
else:
|
||||
self.execute_stmt(m, stmt.alt, table)
|
||||
elif isinstance(stmt, Block):
|
||||
for s in stmt.stmts:
|
||||
self.execute_stmt(s, table)
|
||||
|
@ -45,7 +47,7 @@ class TesterExecuter:
|
|||
def execute_module(self, m: Module, ms: Dict[str, DefModule], table=None):
|
||||
execute_stmts = OrderedDict()
|
||||
instances = OrderedDict()
|
||||
conds = OrderedDict()
|
||||
conds = []
|
||||
|
||||
def get_in_port_name(name: str, t: Type, d: Direction) -> List[str]:
|
||||
if isinstance(d, Input) and isinstance(t, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
|
||||
|
@ -106,7 +108,9 @@ class TesterExecuter:
|
|||
elif isinstance(s, DefInstance):
|
||||
instances[s.name] = s
|
||||
elif isinstance(s, Conditionally):
|
||||
conds[s.pred.expr.serialize()] = s
|
||||
_deal_stmt(s.conseq)
|
||||
_deal_stmt(s.alt)
|
||||
conds.append(s)
|
||||
else:
|
||||
...
|
||||
|
||||
|
@ -119,12 +123,13 @@ class TesterExecuter:
|
|||
for v in self.dags[m.name].travel_graph(inputs):
|
||||
if v in execute_stmts:
|
||||
self.execute_stmt(m, execute_stmts[v], table)
|
||||
if v in conds:
|
||||
stmt = conds[v]
|
||||
if stmt.pred.get_value(table) > 0:
|
||||
_deal_stmt(m.conseq)
|
||||
else:
|
||||
_deal_stmt(m.alt)
|
||||
|
||||
while len(conds) > 0:
|
||||
cond = conds.pop(0)
|
||||
if cond.pred.get_value(table):
|
||||
_deal_stmt(cond.conseq)
|
||||
for execute_stmt in execute_stmts.values():
|
||||
self.execute_stmt(m, execute_stmt, table)
|
||||
|
||||
for ins in instances:
|
||||
ref_module_name = instances[ins].module
|
||||
|
@ -157,23 +162,43 @@ class TesterExecuter:
|
|||
for v in self.dags[m.name].travel_graph(ref_outputs):
|
||||
if v in execute_stmts:
|
||||
self.execute_stmt(m, execute_stmts[v], table)
|
||||
if v in conds:
|
||||
stmt = conds[v]
|
||||
if stmt.pred.get_value(table) > 0:
|
||||
_deal_stmt(m.conseq)
|
||||
else:
|
||||
_deal_stmt(m.alt)
|
||||
|
||||
while len(conds) > 0:
|
||||
cond = conds.pop(0)
|
||||
if cond.pred.get_value(table):
|
||||
_deal_stmt(cond.conseq)
|
||||
for execute_stmt in execute_stmts.values():
|
||||
self.execute_stmt(m, execute_stmt, table)
|
||||
|
||||
def init_clock(self, table = None):
|
||||
if table is None:
|
||||
table = self.symbol_table
|
||||
for mname in self.symbol_table:
|
||||
if mname not in self.clock_table:
|
||||
self.clock_table[mname] = {}
|
||||
for symbol in self.symbol_table[mname]:
|
||||
self.clock_table[mname][symbol] = SingleClockStepper(mname, symbol, self, table)
|
||||
|
||||
def init_executer(self):
|
||||
self.compiler = TesterCompiler(self.symbol_table)
|
||||
self.compiled_circuit, self.dags = self.compiler.compile(self.circuit)
|
||||
self.init_clock()
|
||||
|
||||
def set_value(self, mname: str, name: str, singal: int):
|
||||
self.inputchange = True
|
||||
self.symbol_table.set_symbol_value(mname, self.handle_name(name), singal)
|
||||
|
||||
def get_value(self, mname: str, name: str):
|
||||
if self.inputchange:
|
||||
self.execute(mname)
|
||||
self.inputchange = False
|
||||
return self.symbol_table.get_symbol_value(mname, self.handle_name(name))
|
||||
|
||||
def step(self, n: int, mname: str):
|
||||
if n > 0:
|
||||
for name in self.clock_table[mname]:
|
||||
self.clock_table[mname][name].run(n)
|
||||
|
||||
def execute(self, mname: str):
|
||||
ms = {m.name: m for m in self.compiled_circuit.modules}
|
||||
m = ms[mname]
|
||||
|
|
|
@ -3,6 +3,7 @@ from pyhcl.ir.low_ir import *
|
|||
@dataclass(frozen=True)
|
||||
class SymbolTable:
|
||||
table = {}
|
||||
clock_table = {}
|
||||
|
||||
def gen_typ(self, typ: Type):
|
||||
if isinstance(typ, (AsyncResetType, ResetType, ClockType, UIntType, SIntType)):
|
||||
|
@ -22,6 +23,8 @@ class SymbolTable:
|
|||
|
||||
def set_symbol(self, mname: str, symbol):
|
||||
if isinstance(symbol, Port):
|
||||
if isinstance(symbol.typ, ClockType):
|
||||
self.clock_table[mname][symbol.name] = self.gen_typ(symbol.typ)
|
||||
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
|
||||
if isinstance(symbol, DefWire):
|
||||
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
|
||||
|
|
|
@ -8,12 +8,17 @@ class Tester:
|
|||
self.main = ec.main
|
||||
self.executer = TesterExecuter(ec)
|
||||
self.executer.init_executer()
|
||||
|
||||
|
||||
def peek(self, name: str, value: int):
|
||||
def poke(self, name: str, value: int):
|
||||
self.executer.set_value(self.main, name, value)
|
||||
|
||||
def poke(self, name: str):
|
||||
print(self.executer.get_value(self.main, name))
|
||||
def peek(self, name: str) -> int:
|
||||
res = self.executer.get_value(self.main, name)
|
||||
return int(res, 2) if isinstance(res, str) else res
|
||||
|
||||
def expect(self, a, b) -> bool:
|
||||
return a == b
|
||||
|
||||
def step(self):
|
||||
self.executer.execute(self.main)
|
||||
def step(self, n):
|
||||
self.executer.step(n, self.main)
|
Loading…
Reference in New Issue