feat: add SingleClockStepper

This commit is contained in:
raybdzhou 2022-05-28 09:03:57 +08:00
parent 9858c6b4dd
commit 2e34e0a979
10 changed files with 170 additions and 43 deletions

View File

@ -586,10 +586,12 @@ class DefRegister(Statement):
info: Info = NoInfo()
def serialize(self) -> str:
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
if self.init is not None else ""
# i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
# if self.reset is not None else ""
if self.init:
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
if self.init is not None else ""
else:
i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
if self.reset is not None else ""
return f'reg {self.name} : {self.typ.serialize()}, {self.clock.serialize()}{i}{self.info.serialize()}'
def verilog_serialize(self) -> str:
@ -925,7 +927,7 @@ class Block(Statement):
new_body = self.merge_node_s(self, {}, set())
manager = PassManager(new_body)
new_blocks = manager.renew()
CheckCombLoop.run(new_blocks)
# CheckCombLoop.run(new_blocks)
always_blocks = manager.gen_all_always_block()
return '\n'.join([stmt.verilog_serialize() for stmt in new_blocks.stmts]) + f'\n{always_blocks}' if new_blocks.stmts else ""

View File

@ -15,8 +15,6 @@ class ExpandWhens(Pass):
def flatten(s: Statement) -> List[Statement]:
new_stmts = []
conseq, alt = s.conseq, s.alt
if isinstance(conseq, EmptyStmt) or isinstance(alt, EmptyStmt):
...
if isinstance(conseq, Conditionally):
new_stmts += flatten(conseq)
if isinstance(alt, Conditionally):
@ -38,13 +36,20 @@ class ExpandWhens(Pass):
def expand_whens(stmt: Statement, stmts: List[Statement], reference: Dict[str, Expression]):
if isinstance(stmt, Conditionally):
flat_cond = flatten(stmt)
has_gen = {}
for pred, sx in flat_cond:
if isinstance(sx, Connect):
name = auto_gen_name()
loc_name = sx.loc.verilog_serialize() if isinstance(sx.loc, SubIndex) else sx.loc.serialize()
loc = sx.loc if loc_name not in reference else reference[loc_name]
stmts.append(DefNode(name, Mux(pred, sx.expr, loc, sx.expr.typ)))
reference[loc_name] = Reference(name, sx.loc.typ)
if pred.serialize() in has_gen:
s = stmts.pop()
stmts.append(DefNode(s.name, Mux(pred, has_gen[pred.serialize()], sx.expr, sx.expr.typ)))
has_gen[pred.serialize()] = sx.expr
else:
name = auto_gen_name()
loc_name = sx.loc.verilog_serialize() if isinstance(sx.loc, SubIndex) else sx.loc.serialize()
loc = sx.loc if loc_name not in reference else reference[loc_name][1]
stmts.append(DefNode(name, Mux(pred, sx.expr, loc, sx.expr.typ)))
reference[loc_name] = (sx.loc, Reference(name, sx.loc.typ))
has_gen[pred.serialize()] = sx.expr
else:
stmts.append(sx)
else:
@ -56,7 +61,7 @@ class ExpandWhens(Pass):
for stmt in stmts:
expand_whens(stmt, new_stmts, reference)
for ref in reference:
new_stmts.append(Connect(Reference(ref, reference[ref].typ), reference[ref]))
new_stmts.append(Connect(reference[ref][0], reference[ref][1]))
return new_stmts
def expand_whens_m(m: DefModule) -> DefModule:

View File

@ -33,8 +33,8 @@ class InferTypes(Pass):
if isinstance(s, DefRegister):
typs[s.name] = s.typ
clock = infer_types_e(typs, s.clock) if hasattr(s, 'clock') and isinstance(s.clock, Expression) else None
reset = UIntLiteral(0, IntWidth(1))
init = Reference(s.name, s.typ)
reset = s.reset if hasattr(s, 'reset') and isinstance(s.init, Expression) else Reference('reset', ResetType())
init = s.init if hasattr(s, 'init') and isinstance(s.init, Expression) else Reference(s.name, s.typ)
return DefRegister(s.name, s.typ, clock, reset, init, s.info)
elif isinstance(s, DefWire):
typs[s.name] = s.typ

View File

@ -78,8 +78,13 @@ class ReplaceSubaccess(Pass):
stmts.append(node)
nodes.append(node)
else:
node = DefNode(auto_gen_name(), Mux(exprs[i], index[i],
Reference(nodes[-1].name, nodes[-1].value.typ), get_groud_type(get_type(target_e))))
if isinstance(nodes[-1].value, ValidIf):
node = DefNode(auto_gen_name(), Mux(exprs[i], index[i],
nodes[-1].value.value, get_groud_type(get_type(target_e))))
stmts.pop()
else:
node = DefNode(auto_gen_name(), Mux(exprs[i], index[i],
Reference(nodes[-1].name, nodes[-1].value.typ), get_groud_type(get_type(target_e))))
stmts.append(node)
nodes.append(node)
return Reference(nodes[-1].name, nodes[-1].value.typ)

View File

@ -1,4 +1,3 @@
from curses import endwin
import math
from typing import List

View File

@ -0,0 +1,75 @@
from abc import ABC, abstractclassmethod
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.tester.symbol_table import SymbolTable
from pyhcl.tester.executer import TesterExecuter
class ClockStepper(ABC):
@abstractclassmethod
def bump_clock(self):
...
@abstractclassmethod
def run(self):
...
@abstractclassmethod
def get_cycle_count(self):
...
@abstractclassmethod
def combinational_bump(self):
...
class SingleClockStepper(ClockStepper):
def __init__(self, mname: str, symbol: str, executor: TesterExecuter, table: SymbolTable):
self.mname: str = mname
self.clock_symbol: str = symbol
self.executor: TesterExecuter = executor
self.table: SymbolTable = table
self.clock_cycles = 0
self.combinational_bumps = 0
def handle_name(self, name):
names = name.split(".")
names.reverse()
return names
def bump_clock(self, mname: str, clock_symbol: str, value: int):
self.table.set_symbol_value(mname, self.handle_name(clock_symbol), value)
self.clock_cycles += 1
def combinational_bump(self, value: int):
self.combinational_bumps += value
def get_cycle_count(self):
return self.clock_cycles
def run(self, steps: int):
def raise_clock():
self.table.set_symbol_value(self.mname, self.handle_name(self.clock_symbol), 1)
self.executor.execute(self.mname)
self.combinational_bumps = 0
def lower_clock():
self.table.set_symbol_value(self.mname, self.handle_name(self.clock_symbol), 0)
self.combinational_bumps = 0
for _ in range(steps):
if self.executor.get_inputchange():
self.executor.execute(self.mname)
self.clock_cycles += 1
raise_clock()
lower_clock()
class MultiClockStepper(ClockStepper):
# TODO: Add MultiCLockStepper
...

View File

@ -73,6 +73,7 @@ class TesterCompiler:
lambda table=None: get_func(table),
lambda s, table=None: set_func(s, table))
elif isinstance(expr, SubIndex):
# size = expr.expr.typ.size-1
names.append(expr.value)
e = self.gen_working_ir(mname, names, expr.expr)
get_func, set_func = e.get_func, e.set_func
@ -136,7 +137,14 @@ class TesterCompiler:
return int(value, 2)
return lambda table=None: bits(args, consts, table)
elif isinstance(op, Cat):
return lambda table=None: int(bin(args[0].get_value(table))[2:]+bin(args[1].get_value(table))[2:], 2)
def cat(args, table):
hi = args[0].get_value(table) if isinstance(args[0].get_value(table), str) else bin(args[0].get_value(table))[2:]
lo = args[1].get_value(table) if isinstance(args[1].get_value(table), str) else bin(args[1].get_value(table))[2:]
# max_len = len(hi) if len(hi) >= len(lo) else len(lo)
# hi = '{:032b}'.format(args[0].get_value(table))[-max_len:]
# lo = '{:032b}'.format(args[1].get_value(table))[-max_len:]
return hi+lo
return lambda table=None: cat(args, table)
elif isinstance(op, (AsUInt, AsSInt)):
return lambda table=None: int(args[0].get_value(table))
elif isinstance(op, AsClock):

View File

@ -5,17 +5,24 @@ from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.tester.compiler import TesterCompiler
from pyhcl.tester.symbol_table import SymbolTable
from pyhcl.tester.clock_stepper import SingleClockStepper
class TesterExecuter:
def __init__(self, circuit: Circuit):
self.circuit = circuit
self.symbol_table = SymbolTable()
self.clock_table = {}
self.inputchange = False
def handle_name(self, name):
names = name.split(".")
names.reverse()
return names
def get_inputchange(self):
return self.inputchange
def get_ref_name(self, e: Expression):
if isinstance(e, SubField):
return self.get_ref_name(e.expr)
@ -31,11 +38,6 @@ class TesterExecuter:
stmt.loc.set_value(stmt.expr.get_value(table), table)
elif isinstance(stmt, DefNode):
self.symbol_table.set_symbol_value(m.name, self.handle_name(stmt.name), stmt.value.get_value(table), table)
elif isinstance(stmt, Conditionally):
if stmt.pred.get_value(table) > 0:
self.execute_stmt(m, stmt.conseq, table)
else:
self.execute_stmt(m, stmt.alt, table)
elif isinstance(stmt, Block):
for s in stmt.stmts:
self.execute_stmt(s, table)
@ -45,7 +47,7 @@ class TesterExecuter:
def execute_module(self, m: Module, ms: Dict[str, DefModule], table=None):
execute_stmts = OrderedDict()
instances = OrderedDict()
conds = OrderedDict()
conds = []
def get_in_port_name(name: str, t: Type, d: Direction) -> List[str]:
if isinstance(d, Input) and isinstance(t, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
@ -106,7 +108,9 @@ class TesterExecuter:
elif isinstance(s, DefInstance):
instances[s.name] = s
elif isinstance(s, Conditionally):
conds[s.pred.expr.serialize()] = s
_deal_stmt(s.conseq)
_deal_stmt(s.alt)
conds.append(s)
else:
...
@ -119,12 +123,13 @@ class TesterExecuter:
for v in self.dags[m.name].travel_graph(inputs):
if v in execute_stmts:
self.execute_stmt(m, execute_stmts[v], table)
if v in conds:
stmt = conds[v]
if stmt.pred.get_value(table) > 0:
_deal_stmt(m.conseq)
else:
_deal_stmt(m.alt)
while len(conds) > 0:
cond = conds.pop(0)
if cond.pred.get_value(table):
_deal_stmt(cond.conseq)
for execute_stmt in execute_stmts.values():
self.execute_stmt(m, execute_stmt, table)
for ins in instances:
ref_module_name = instances[ins].module
@ -157,23 +162,43 @@ class TesterExecuter:
for v in self.dags[m.name].travel_graph(ref_outputs):
if v in execute_stmts:
self.execute_stmt(m, execute_stmts[v], table)
if v in conds:
stmt = conds[v]
if stmt.pred.get_value(table) > 0:
_deal_stmt(m.conseq)
else:
_deal_stmt(m.alt)
while len(conds) > 0:
cond = conds.pop(0)
if cond.pred.get_value(table):
_deal_stmt(cond.conseq)
for execute_stmt in execute_stmts.values():
self.execute_stmt(m, execute_stmt, table)
def init_clock(self, table = None):
if table is None:
table = self.symbol_table
for mname in self.symbol_table:
if mname not in self.clock_table:
self.clock_table[mname] = {}
for symbol in self.symbol_table[mname]:
self.clock_table[mname][symbol] = SingleClockStepper(mname, symbol, self, table)
def init_executer(self):
self.compiler = TesterCompiler(self.symbol_table)
self.compiled_circuit, self.dags = self.compiler.compile(self.circuit)
self.init_clock()
def set_value(self, mname: str, name: str, singal: int):
self.inputchange = True
self.symbol_table.set_symbol_value(mname, self.handle_name(name), singal)
def get_value(self, mname: str, name: str):
if self.inputchange:
self.execute(mname)
self.inputchange = False
return self.symbol_table.get_symbol_value(mname, self.handle_name(name))
def step(self, n: int, mname: str):
if n > 0:
for name in self.clock_table[mname]:
self.clock_table[mname][name].run(n)
def execute(self, mname: str):
ms = {m.name: m for m in self.compiled_circuit.modules}
m = ms[mname]

View File

@ -3,6 +3,7 @@ from pyhcl.ir.low_ir import *
@dataclass(frozen=True)
class SymbolTable:
table = {}
clock_table = {}
def gen_typ(self, typ: Type):
if isinstance(typ, (AsyncResetType, ResetType, ClockType, UIntType, SIntType)):
@ -22,6 +23,8 @@ class SymbolTable:
def set_symbol(self, mname: str, symbol):
if isinstance(symbol, Port):
if isinstance(symbol.typ, ClockType):
self.clock_table[mname][symbol.name] = self.gen_typ(symbol.typ)
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
if isinstance(symbol, DefWire):
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)

View File

@ -8,12 +8,17 @@ class Tester:
self.main = ec.main
self.executer = TesterExecuter(ec)
self.executer.init_executer()
def peek(self, name: str, value: int):
def poke(self, name: str, value: int):
self.executer.set_value(self.main, name, value)
def poke(self, name: str):
print(self.executer.get_value(self.main, name))
def peek(self, name: str) -> int:
res = self.executer.get_value(self.main, name)
return int(res, 2) if isinstance(res, str) else res
def expect(self, a, b) -> bool:
return a == b
def step(self):
self.executer.execute(self.main)
def step(self, n):
self.executer.step(n, self.main)