PyHCL/pyhcl/passes/check_widths.py

182 lines
8.6 KiB
Python

from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass, PassException, Error
from pyhcl.passes.utils import get_width, get_info
from pyhcl.passes.check_types import IllegalResetType, CheckTypes, InvalidConnect
# MaxWidth
maxWidth = 1000000
class UninferredWidth(PassException):
def __init__(self, info: Info, target: str):
super().__init__(f'{info}: Uninferred width for target below. (Did you forget to assign to it?) \n{target}')
class InvalidRange(PassException):
def __init__(self, info: Info, target: str, i: Type):
super().__init__(f'{info}: Invalid range {i.serialize()} for target below. (Are the bounds valid?) \n{target}')
class WidthTooSmall(PassException):
def __init__(self, info: Info, mname: str, b: int):
super().__init__(f'{info} : [target {mname}] Width too small for constant {b}.')
class WidthTooBig(PassException):
def __init__(self, info: Info, mname: str, b: int):
super().__init__(f'{info} : [target ${mname}] Width {b} greater than max allowed width of {maxWidth} bits')
class DshlTooBig(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info} : [target {mname}] Width of dshl shift amount must be less than {maxWidth} bits.')
class MultiBitAsClock(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info} : [target {mname}] Cannot cast a multi-bit signal to a Clock.')
class MultiBitAsAsyncReset(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info} : [target {mname}] Cannot cast a multi-bit signal to an AsyncReset.')
class NegWidthException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [target {mname}] Width cannot be negative or zero.')
class BitsWidthException(PassException):
def __init__(self, info: Info, mname: str, hi: int, width: int, exp: str):
super().__init__(f'{info}: [target {mname}] High bit {hi} in bits operator is larger than input width {width} in {exp}.')
class HeadWidthException(PassException):
def __init__(self, info: Info, mname: str, n: int, width: int):
super().__init__(f'{info}: [target {mname}] Parameter {n} in head operator is larger than input width {width}.')
class TailWidthException(PassException):
def __init__(self, info: Info, mname: str, n: int, width: int):
super().__init__(f'{info}: [target {mname}] Parameter {n} in tail operator is larger than input width {width}.')
class CheckWidths(Pass):
def run(self, c: Circuit):
errors = Error()
def gen_target(name: str, subname: str) -> str:
return f'{name}-{subname}'
def check_width_w(info: Info, target: str, t: Type, w: Width):
if isinstance(w, IntWidth) and w.width >= maxWidth:
errors.append(WidthTooBig(info, target, w.width))
elif isinstance(w, IntWidth):
...
else:
errors.append(UninferredWidth(info, target))
def has_width(typ: Type) -> bool:
if isinstance(typ, GroundType) and hasattr(typ, 'width') and isinstance(typ.width, IntWidth):
return True
elif isinstance(typ, GroundType):
return False
else:
raise PassException(f'hasWidth - {typ}')
def check_width_t(info: Info, target: str, t: Type):
if isinstance(t, BundleType):
for f in t.fields:
check_width_f(info, target, f)
else:
for _, tt in t.__dict__.items():
if isinstance(tt, Type):
check_width_t(info, target, tt)
for _, tt in t.__dict__.items():
if isinstance(tt, Width):
check_width_w(info, target, t, tt)
def check_width_f(info: Info, target: str, f: Field):
check_width_t(info, target, f.typ)
def check_width_e_leaf(info: Info, target: str, expr: Expression):
if isinstance(expr, UIntLiteral) and get_binary_width(expr.value) > get_width(expr.width):
errors.append(WidthTooSmall(info, target, expr.value))
elif isinstance(expr, SIntLiteral) and get_binary_width(expr.value) + 1 > get_width(expr.width):
errors.append(WidthTooSmall(info, target, expr.value))
elif isinstance(expr, DoPrim) and len(expr.args) == 2:
if isinstance(expr.op, Dshl) and has_width(expr.args[0].typ) and get_width(expr.args[1].typ.width) > maxWidth:
errors.append(DshlTooBig(info, target))
elif isinstance(expr, DoPrim) and len(expr.args) == 1:
if isinstance(expr.op, Bits) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.consts[0]:
errors.append(BitsWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width), expr.serialize()))
elif isinstance(expr.op, Head) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.args[0]:
errors.append(HeadWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width)))
elif isinstance(expr.op, Tail) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.args[0]:
errors.append(TailWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width)))
elif isinstance(expr.op, AsClock) and expr.consts[0] != 1:
errors.append(MultiBitAsClock(info, target))
def check_width_e(info: Info, target: str, rec_depth: int, e: Expression):
check_width_e_leaf(info, target, e)
if isinstance(e, (Mux, ValidIf, DoPrim)):
if rec_depth > 0:
for _, ee in e.__dict__.items():
if isinstance(ee, Expression):
check_width_e(info, target, rec_depth - 1, ee)
else:
check_width_e_dfs(info, target, e)
def check_width_e_dfs(info: Info, target: str, expr: Expression):
stack = expr.__dict__.items()
def push(e: Expression):
stack.append(e)
while len(stack) > 0:
current = stack
check_width_e_leaf(info, target, current)
for _, leaf in current.__dict__.items():
if isinstance(leaf, Expression):
push(leaf)
def check_width_s(minfo: Info, target: str, s: Statement):
info = get_info(s)
if isinstance(info, NoInfo):
info = minfo
for _, ss in s.__dict__.items():
if isinstance(ss, Expression):
check_width_e(info, target, 4, ss)
if isinstance(ss, Statement):
check_width_s(info, target, ss)
if isinstance(ss, Type):
check_width_t(info, target, ss)
if isinstance(s, DefRegister):
sx = s.reset.typ if isinstance(s.reset, Expression) else None
if sx is None:
...
elif isinstance(sx, UIntType) and get_width(sx.width) == 1:
...
elif isinstance(sx, AsyncResetType):
...
elif isinstance(sx, ResetType):
...
else:
errors.append(IllegalResetType(info, target, s.name))
if isinstance(s.init, Expression) and CheckTypes.valid_connect(s.typ, s.init.typ) is False:
con_msg = DefRegister(s.name, s.typ, s.clock, s.reset, s.init, NoInfo())
errors.append(InvalidConnect(info, target, con_msg, _, s.init))
def check_width_p(minfo: Info, target: str, p: Port):
check_width_t(p.info, target, p.typ)
def check_width_m(target: str, m: DefModule):
if hasattr(m, 'ports') and isinstance(m.ports, list):
for p in m.ports:
check_width_p(m.info, gen_target(target, m.name), p)
if hasattr(m, 'body') and isinstance(m.body, Block):
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
for s in m.body.stmts:
check_width_s(m.info, gen_target(target, m.name), s)
for m in c.modules:
check_width_m(c.main, m)
errors.trigger()
return c