core: pritty print (#33)

This commit is contained in:
Gaufoo 2019-12-17 20:20:54 +08:00 committed by GitHub
parent adb6cc6271
commit 011a8ff657
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 642 additions and 247 deletions

5
.gitignore vendored
View File

@ -107,4 +107,7 @@ venv.bak/
.idea/
# VSCode
.vscode
.vscode
temp

View File

@ -1,4 +1,6 @@
def install_ops():
import py_hcl.core.expr.add # noqa: F401
import py_hcl.core.stmt.connect # noqa: F401
import py_hcl.core.expr.field_access # noqa: F401
import py_hcl.core.expr.field # noqa: F401
import py_hcl.core.expr.bits # noqa: F401
import py_hcl.core.expr.convert # noqa: F401

View File

@ -1,7 +1,10 @@
from enum import Enum
import logging
from py_hcl.core.hcl_ops import hcl_call
from py_hcl.core.type import HclType
from multipledispatch.dispatcher import MethodDispatcher
from py_hcl.core.hcl_ops import op_apply
from py_hcl.core.type import UnknownType, HclType
from py_hcl.utils import auto_repr
@ -14,15 +17,56 @@ class ConnDir(Enum):
@auto_repr
class HclExpr(object):
def __init__(self):
self.hcl_type = HclType()
self.conn_dir = ConnDir.UNKNOWN
hcl_type = UnknownType()
conn_dir = ConnDir.UNKNOWN
def __ilshift__(self, other):
return hcl_call('<<=')(self, other)
return op_apply('<<=')(self, other)
def __add__(self, other):
return hcl_call('+')(self, other)
return op_apply('+')(self, other)
def __getattr__(self, item):
return hcl_call('.')(self, item)
return op_apply('.')(self, item)
__getitem__ = MethodDispatcher('__getitem__')
@__getitem__.register(tuple)
def _(self, item):
logging.warning('slice(): too many index blocks, '
'only the first one takes effect')
item = item[0]
return self.__getitem__(item)
@__getitem__.register(slice)
def _(self, item):
"""
o[5:2]
"""
assert item.start is not None
assert item.stop is not None
assert item.step is None
return op_apply('[i:j]')(self, item.start, item.stop)
@__getitem__.register(int)
def _(self, item):
"""
o[5]
"""
return op_apply('[i]')(self, item)
def to_uint(self):
return op_apply('to_uint')(self)
def to_sint(self):
return op_apply('to_sint')(self)
def to_bool(self):
return op_apply('to_bool')(self)
class ExprHolder(HclExpr):
def __init__(self, hcl_type: HclType, conn_dir: ConnDir, assoc_value):
self.hcl_type = hcl_type
self.conn_dir = conn_dir
self.assoc_value = assoc_value

View File

@ -1,8 +1,7 @@
from py_hcl.core.expr.error import ExprError
from py_hcl.core.expr.place import ExprPlace
from py_hcl.core.expr import ConnDir
from py_hcl.core.hcl_ops import hcl_operation
from py_hcl.core.type import HclType
from py_hcl.core.expr import ConnDir, ExprHolder
from py_hcl.core.hcl_ops import op_register
from py_hcl.core.type.sint import SIntT
from py_hcl.core.type.uint import UIntT
from py_hcl.utils import auto_repr
@ -14,22 +13,30 @@ class Add(object):
self.right = right
adder = hcl_operation('+')
adder = op_register('+')
@adder(UIntT, UIntT)
def _(lf, rt):
check_add_dir(lf, rt)
w = max(lf.hcl_type.width, rt.hcl_type.width) + 1
t = UIntT(w)
return ExprPlace(t, Add(lf, rt), ConnDir.RT)
return ExprHolder(t, ConnDir.RT, Add(lf, rt))
@adder(HclType, HclType)
@adder(SIntT, SIntT)
def _(lf, rt):
# TODO: temporary
return ExprPlace(HclType(), Add(lf, rt), ConnDir.RT)
check_add_dir(lf, rt)
w = max(lf.hcl_type.width, rt.hcl_type.width) + 1
t = SIntT(w)
return ExprHolder(t, ConnDir.RT, Add(lf, rt))
@adder(object, object)
def _(lf: object, rt: object):
raise ExprError.add(lf, rt)
def _(_0, _1):
raise ExprError.op_type_err('add', _0, _1)
def check_add_dir(lf, rt):
assert lf.conn_dir in (ConnDir.RT, ConnDir.BOTH)
assert rt.conn_dir in (ConnDir.RT, ConnDir.BOTH)

56
py_hcl/core/expr/bits.py Normal file
View File

@ -0,0 +1,56 @@
from py_hcl.core.expr import ExprHolder, ConnDir
from py_hcl.core.expr.error import ExprError
from py_hcl.core.hcl_ops import op_register
from py_hcl.core.type.sint import SIntT
from py_hcl.core.type.uint import UIntT
from py_hcl.utils import auto_repr
slice_ = op_register('[i:j]')
index = op_register('[i]')
@auto_repr
class Bits(object):
def __init__(self, expr, high, low):
self.high = high
self.low = low
self.expr = expr
@slice_(UIntT)
def _(uint, high: int, low: int):
check_bit_width(uint, high, low)
t = UIntT(high - low + 1)
return ExprHolder(t, ConnDir.RT, Bits(uint, high, low))
@slice_(SIntT)
def _(sint, high: int, low: int):
check_bit_width(sint, high, low)
t = UIntT(high - low + 1)
return ExprHolder(t, ConnDir.RT, Bits(sint, high, low))
@slice_(object)
def _(_0, *_):
ExprError.op_type_err('slice', _0)
@index(UIntT)
def _(uint, i: int):
return uint[i:i]
@index(SIntT)
def _(sint, i: int):
return sint[i:i]
@index(object)
def _(_0, *_):
ExprError.op_type_err('index', _0)
def check_bit_width(uint, high, low):
w = uint.hcl_type.width
assert w > high >= low >= 0

View File

@ -0,0 +1,53 @@
from py_hcl.core.expr import ExprHolder, ConnDir
from py_hcl.core.hcl_ops import op_register
from py_hcl.core.type.sint import SIntT
from py_hcl.core.type.uint import UIntT
from py_hcl.utils import auto_repr
to_bool = op_register('to_bool')
to_uint = op_register('to_uint')
to_sint = op_register('to_sint')
@auto_repr
class ToSInt(object):
def __init__(self, expr):
self.expr = expr
@auto_repr
class ToUInt(object):
def __init__(self, expr):
self.expr = expr
@to_bool(UIntT)
def _(uint):
return uint[0]
@to_bool(SIntT)
def _(sint):
return sint[0]
@to_uint(UIntT)
def _(uint):
return uint
@to_uint(SIntT)
def _(sint):
t = UIntT(sint.hcl_type.width)
return ExprHolder(t, ConnDir.RT, ToUInt(sint))
@to_sint(UIntT)
def _(uint):
t = SIntT(uint.hcl_type.width)
return ExprHolder(t, ConnDir.RT, ToSInt(uint))
@to_sint(SIntT)
def _(sint):
return sint

View File

@ -6,7 +6,7 @@ def set_up():
'IOValueError': {
'code': 200,
'value': ExprError('io items should wrap with Input or Output')},
'AddError': {
'OpTypeError': {
'code': 201,
'value': ExprError('specified arguments contain unexpected types')
}
@ -19,10 +19,10 @@ class ExprError(CoreError):
return ExprError.err('IOValueError', msg)
@staticmethod
def add(lf, rt):
return ExprError.err('AddError',
'unsupported operand types: {} and {}'.format(
type(lf), type(rt)))
def op_type_err(op, *args):
ts = ', '.join([type(a.hcl_type).__name__ for a in args])
msg = '{}(): unsupported operand types: {}'.format(op, ts)
return ExprError.err('OpTypeError', msg)
set_up()

27
py_hcl/core/expr/field.py Normal file
View File

@ -0,0 +1,27 @@
from py_hcl.core.expr import ConnDir, ExprHolder
from py_hcl.core.expr.error import ExprError
from py_hcl.core.hcl_ops import op_register
from py_hcl.core.type.bundle import BundleT, Dir
from py_hcl.utils import auto_repr
field_accessor = op_register('.')
@auto_repr
class FieldAccess(object):
def __init__(self, expr, item):
self.item = item
self.expr = expr
@field_accessor(BundleT)
def _(bd, item):
assert item in bd.hcl_type.types
dr, tpe = bd.hcl_type.types[item]
cd = ConnDir.RT if dr == Dir.IN else ConnDir.LF
return ExprHolder(tpe, cd, FieldAccess(bd, item))
@field_accessor(object)
def _(o, *_):
raise ExprError.op_type_err('field_accessor', o)

View File

@ -1,22 +0,0 @@
from py_hcl.core.expr.place import ExprPlace
from py_hcl.core.expr import ConnDir
from py_hcl.core.hcl_ops import hcl_operation
from py_hcl.core.type.bundle import BundleT, Dir
from py_hcl.utils import auto_repr
field_accessor = hcl_operation('.')
@auto_repr
class FieldAccess(object):
def __init__(self, obj, item):
self.obj = obj
self.item = item
@field_accessor(BundleT)
def _(bd, item):
assert item in bd.hcl_type.types
dr, tpe = bd.hcl_type.types[item]
cd = ConnDir.RT if dr == Dir.IN else ConnDir.LF
return ExprPlace(tpe, FieldAccess(bd, item), cd)

View File

@ -1,12 +1,24 @@
from typing import Dict, Union
from py_hcl.core.expr.error import ExprError
from py_hcl.core.expr import HclExpr, ConnDir
from py_hcl.core.type import HclType
from py_hcl.core.type.bundle import Dir, BundleT
from py_hcl.utils import _fm, _indent
class Input(object):
def __init__(self, hcl_type: HclType):
self.hcl_type = hcl_type
class Output(object):
def __init__(self, hcl_type: HclType):
self.hcl_type = hcl_type
class IO(HclExpr):
def __init__(self, **named_ports):
super().__init__()
def __init__(self, named_ports: Dict[str, Union[Input, Output]]):
self.hcl_type = IO.handle_args(named_ports)
self.conn_dir = ConnDir.BOTH
@ -27,12 +39,8 @@ class IO(HclExpr):
return BundleT(types)
class Input(object):
def __init__(self, hcl_type: HclType):
self.hcl_type = hcl_type
class Output(object):
def __init__(self, hcl_type: HclType):
self.hcl_type = hcl_type
def __repr__(self):
return 'IO {\n' \
' conn_dir=%s\n' \
' hcl_type=%s\n' \
'}' % (self.conn_dir, _indent(_fm(self.hcl_type)))

View File

@ -0,0 +1,12 @@
from py_hcl.core.expr import HclExpr, ConnDir
from py_hcl.core.type.sint import SIntT
from py_hcl.utils import signed_num_bin_len
class SLiteral(HclExpr):
def __init__(self, value: int):
self.value = value
w = signed_num_bin_len(value)
self.hcl_type = SIntT(w)
self.conn_dir = ConnDir.RT

View File

@ -5,8 +5,8 @@ from py_hcl.utils import unsigned_num_bin_len
class ULiteral(HclExpr):
def __init__(self, value: int):
super().__init__()
self.value = value
w = unsigned_num_bin_len(value)
self.hcl_type = UIntT(w)
self.conn_dir = ConnDir.RT

View File

@ -1,10 +0,0 @@
from py_hcl.core.expr import HclExpr, ConnDir
from py_hcl.core.type import HclType
class ExprPlace(HclExpr):
def __init__(self, hcl_type: HclType, assoc_value, conn_dir: ConnDir):
super().__init__()
self.hcl_type = hcl_type
self.assoc_value = assoc_value
self.conn_dir = conn_dir

View File

@ -3,18 +3,30 @@ from multipledispatch import Dispatcher
op_map = {
'+': Dispatcher('+'),
'<<=': Dispatcher('<<='),
'.': Dispatcher('.')
'.': Dispatcher('.'),
'[i]': Dispatcher('[i]'),
'[i:j]': Dispatcher('[i:j]'),
'to_sint': Dispatcher('to_sint'),
'to_uint': Dispatcher('to_uint'),
'to_bool': Dispatcher('to_bool'),
}
def hcl_operation(operation):
def op_register(operation):
return op_map[operation].register
def hcl_call(operation):
def op_apply(operation):
def _(*objects):
types = [type(o.hcl_type) for o in objects if hasattr(o, 'hcl_type')]
func = op_map[operation].dispatch(*types)
return func(*objects)
if func is not None:
return func(*objects)
msg = 'No matched functions for types {} while calling operation ' \
'"{}"'.format([t.__name__ for t in types], operation)
raise NotImplementedError(msg)
return _

View File

@ -1,4 +1,4 @@
from . import packer
from py_hcl.core.module_factory import packer
class MetaModule(type):

View File

@ -1,5 +1,9 @@
from py_hcl.utils import auto_repr
@auto_repr
class PackedModule(object):
def __init__(self, name, named_expressions, top_scope):
def __init__(self, name, named_expressions, top_statement):
self.name = name
self.named_expressions = named_expressions
self.top_scope = top_scope
self.top_statement = top_statement

View File

View File

@ -5,7 +5,7 @@ def set_up():
ModuleError.append({
'NotContainsIO': {
'code': 100,
'value': ModuleError('the module lack of io attribute')},
'value': ModuleError('the module_factory lack of io attribute')},
'InheritDuplicateName': {
'code': 101,

View File

@ -1,4 +1,4 @@
from py_hcl.core.module.error import ModuleError
from py_hcl.core.module_factory.error import ModuleError
from py_hcl.core.expr import HclExpr
@ -18,4 +18,4 @@ def extract(dct, name):
def check_io_exist(res, name):
if 'io' not in res:
raise ModuleError.not_contains_io(
'module {} lack of io attribute'.format(name))
'module_factory {} lack of io attribute'.format(name))

View File

@ -1,5 +1,5 @@
from py_hcl.core.module.error import ModuleError
from py_hcl.dsl.expr.io import IO
from py_hcl.core.expr.io import IO
from py_hcl.core.module_factory.error import ModuleError
def merge_expr(dest, src, mod_names):
@ -22,8 +22,8 @@ def check_dup_mod(dest, src, mod_names):
dest_name = mod_names[0]
src_name = mod_names[1]
raise ModuleError.duplicate_name(
'module {} has duplicates with {} in '
'module {}'.format(dest_name, list(a), src_name)
'module_factory {} has duplicates with {} in '
'module_factory {}'.format(dest_name, list(a), src_name)
)
@ -44,8 +44,8 @@ def check_dup_io(dest, src, mod_names):
dest_name = mod_names[0]
src_name = mod_names[1]
raise ModuleError.duplicate_name(
'module {} has duplicates with {} in '
'module {} in io'.format(dest_name, p, src_name)
'module_factory {} has duplicates with {} in '
'module_factory {} in io'.format(dest_name, p, src_name)
)

View File

@ -1,31 +1,31 @@
from ..stmt.trapper import StatementTrapper
from py_hcl.core.stmt_factory.trapper import StatementTrapper
from . import merger
from . import extractor
from .packed_module import PackedModule
from py_hcl.core.module.packed_module import PackedModule
def pack(bases, dct, name):
raw_expr = extractor.extract(dct, name)
raw_scope = StatementTrapper.trap()
named_expression, top_scope = \
named_expression, top_statement = \
handle_inherit(bases, raw_expr, raw_scope, name)
res = PackedModule(name, named_expression, top_scope)
res = PackedModule(name, named_expression, top_statement)
return res
def handle_inherit(bases, named_expression, top_scope, name):
def handle_inherit(bases, named_expression, top_statement, name):
for b in bases:
if not hasattr(b, 'packed_module'):
continue
pm = b.packed_module
expr = pm.named_expressions
ts = pm.top_scope
ts = pm.top_statement
named_expression = merger.merge_expr(named_expression, expr,
(name, pm.name))
top_scope = merger.merge_scope(top_scope, ts, (name, pm.name))
top_statement = merger.merge_scope(top_statement, ts, (name, pm.name))
return named_expression, top_scope
return named_expression, top_statement

View File

@ -0,0 +1,15 @@
from py_hcl.utils import auto_repr
@auto_repr
class LineStatement(object):
def __init__(self, scope_id, statement):
self.scope_id = scope_id
self.statement = statement
@auto_repr
class BlockStatement(object):
def __init__(self, scope_info, stmts):
self.scope_info = scope_info
self.statements = stmts

View File

@ -1,25 +1,31 @@
from py_hcl.core.expr import HclExpr
from py_hcl.core.stmt.error import StatementError
from py_hcl.core.stmt.scope import ScopeManager, ScopeType
from py_hcl.core.stmt.trapper import StatementTrapper
from py_hcl.core.stmt_factory.scope import ScopeManager, ScopeType
from py_hcl.core.stmt import BlockStatement
from py_hcl.core.stmt_factory.trapper import StatementTrapper
from py_hcl.core.type.uint import UIntT
from py_hcl.utils import auto_repr
@auto_repr
class When(object):
def __init__(self, cond):
def __init__(self, cond: HclExpr):
self.cond = cond
@auto_repr
class ElseWhen(object):
def __init__(self, cond):
def __init__(self, cond: HclExpr):
self.cond = cond
@auto_repr
class Otherwise(object):
pass
def do_when_enter(cond_expr):
# TODO: need some checks
print('do_when_enter: need some check')
def do_when_enter(cond_expr: HclExpr):
check_bool_expr(cond_expr)
w = When(cond_expr)
ScopeManager.expand_scope(ScopeType.WHEN, w)
@ -29,12 +35,10 @@ def do_when_exit():
ScopeManager.shrink_scope()
def do_else_when_enter(cond_expr):
def do_else_when_enter(cond_expr: HclExpr):
check_bool_expr(cond_expr)
check_branch_syntax()
# TODO: need some checks
print('do_else_when_enter: need some check')
e = ElseWhen(cond_expr)
ScopeManager.expand_scope(ScopeType.ELSE_WHEN, e)
@ -54,20 +58,56 @@ def do_otherwise_exit():
ScopeManager.shrink_scope()
def check_bool_expr(cond_expr: HclExpr):
if isinstance(cond_expr.hcl_type, UIntT) and cond_expr.hcl_type.width == 1:
return
raise StatementError.wrong_branch_syntax(
'check_bool_expr(): '
'expected bool-type expression')
def check_branch_syntax():
check_exists_pre_stmts()
check_exists_pre_when_block()
check_correct_block_level()
def check_exists_pre_stmts():
if len(StatementTrapper.trapped_stmts[-1]) == 0:
raise StatementError.wrong_branch_syntax('expected when block')
raise StatementError.wrong_branch_syntax(
'check_exists_pre_stmts(): '
'expected when block')
def check_exists_pre_when_block():
last_stmt = StatementTrapper.trapped_stmts[-1][-1]
last_scope = last_stmt['scope']
last_scope_type = last_scope['scope_type']
if last_scope_type != ScopeType.WHEN and \
last_scope_type != ScopeType.ELSE_WHEN:
raise StatementError.wrong_branch_syntax('expected when block or '
'else_when block')
current_scope = ScopeManager.current_scope()
last_scope_level = last_scope['scope_level']
current_scope_level = current_scope['scope_level']
if last_scope_level != current_scope_level + 1:
raise StatementError.wrong_branch_syntax('branch block not matched')
if isinstance(last_stmt, BlockStatement):
last_scope = last_stmt.scope_info
last_scope_type = last_scope['scope_type']
when = last_scope_type == ScopeType.WHEN
else_when = last_scope_type == ScopeType.ELSE_WHEN
if when or else_when:
return
raise StatementError.wrong_branch_syntax(
'check_exists_pre_when_block(): '
'expected when block or else_when block')
def check_correct_block_level():
last_stmt = StatementTrapper.trapped_stmts[-1][-1]
if isinstance(last_stmt, BlockStatement):
last_scope = last_stmt.scope_info
current_scope = ScopeManager.current_scope()
last_scope_level = last_scope['scope_level']
current_scope_level = current_scope['scope_level']
if last_scope_level == current_scope_level + 1:
return
raise StatementError.wrong_branch_syntax(
'check_correct_block_level(): '
'branch block not matched')

View File

@ -1,26 +1,68 @@
import logging
from py_hcl.core.expr import ConnDir
from py_hcl.core.hcl_ops import hcl_operation
from py_hcl.core.stmt.trapper import StatementTrapper
from py_hcl.core.hcl_ops import op_register, op_apply
from py_hcl.core.stmt.error import StatementError
from py_hcl.core.stmt_factory.trapper import StatementTrapper
from py_hcl.core.type.sint import SIntT
from py_hcl.core.type.uint import UIntT
from py_hcl.utils import auto_repr
@auto_repr
class Connect(object):
def __init__(self, left, right):
self.left = left
self.right = right
connector = hcl_operation('<<=')
connector = op_register('<<=')
@connector(UIntT, UIntT)
def _(left, right):
check_connect_dir(left, right)
if left.hcl_type.width < right.hcl_type.width:
msg = 'connect(): connecting {} to {} will truncate the bits'.format(
right.hcl_type, left.hcl_type)
logging.warning(msg)
right = right[left.hcl_type.width - 1:0]
assert left.hcl_type.width >= right.hcl_type.width
StatementTrapper.track(Connect(left, right))
return left
@connector(SIntT, SIntT)
def _(left, right):
check_connect_dir(left, right)
if left.hcl_type.width < right.hcl_type.width:
logging.warning(
'connector(): connecting {} to {} will truncate the bits'.format(
right.hcl_type, left.hcl_type
))
right = right[left.hcl_type.width - 1:0].to_sint()
StatementTrapper.track(Connect(left, right))
return left
@connector(UIntT, SIntT)
def _(left, right):
msg = 'connect(): connecting SInt to UInt, an auto-conversion will occur'
logging.warning(msg)
return op_apply('<<=')(left, right.to_uint())
@connector(SIntT, UIntT)
def _(left, right):
msg = 'connect(): connecting SInt to UInt, an auto-conversion will occur'
logging.warning(msg)
return op_apply('<<=')(left, right.to_sint())
@connector(object, object)
def _(left, right):
check_connect_dir(left, right)
# TODO: need some check
print('do_connect: need some check')
StatementTrapper.track(Connect(left, right))
return left
def _(_0, _1):
raise StatementError.connect_type_error(_0, _1)
def check_connect_dir(left, right):

View File

@ -7,7 +7,11 @@ def set_up():
'code': 300,
'value': StatementError(
'expected a well-defined when-else_when-otherwise block')},
'ConnectTypeError': {
'code': 301,
'value': StatementError(
'connect statement contains unexpected types')
}
})
@ -16,5 +20,11 @@ class StatementError(CoreError):
def wrong_branch_syntax(msg):
return StatementError.err('WrongBranchSyntax', msg)
@staticmethod
def connect_type_error(*args):
ts = ', '.join([type(a.hcl_type).__name__ for a in args])
msg = 'connect(): unsupported connect types: {}'.format(ts)
return StatementError.err('ConnectTypeError', msg)
set_up()

View File

@ -1,54 +0,0 @@
from .scope import Scope, ScopeManager, ScopeType
def set_up():
ScopeManager.register_scope_expanding(
StatementTrapper.when_scope_expanding)
ScopeManager.register_scope_shrinking(
StatementTrapper.when_scope_shrinking
)
ScopeManager.expand_scope(ScopeType.GROUND)
class StatementTrapper(object):
trapped_stmts = [[
# TOP SCOPE which contains all modules
]]
@classmethod
def trap(cls):
ScopeManager.shrink_scope()
assert len(cls.trapped_stmts) == 1
t = cls.trapped_stmts[0][-1]
ret = Scope(t['scope'], t['statement'])
ScopeManager.expand_scope(ScopeType.GROUND)
return ret # TODO
@classmethod
def track(cls, statement):
# TODO: wrap the statement
print("StatementTrapper.trace: need wrap the statement")
statement = {
'scope': ScopeManager.current_scope(),
'statement': statement,
}
cls.trapped_stmts[-1].append(statement)
@classmethod
def when_scope_expanding(cls, current_scope, next_scope):
cls.trapped_stmts.append([])
@classmethod
def when_scope_shrinking(cls, current_scope, next_scope):
stmts = cls.trapped_stmts.pop()
cls.trapped_stmts[-1].append({
'scope': current_scope,
'statement': stmts
})
set_up()

View File

View File

@ -9,12 +9,6 @@ class ScopeType(Enum):
OTHERWISE = 4
class Scope(object):
def __init__(self, scope, stmts):
self.scope = scope
self.statements = stmts
class ScopeLevelManager(object):
_next_scope_level = 0
@ -49,6 +43,9 @@ class ScopeManager(object):
'scope_type': ScopeType.TOP,
'tag_object': None
}]
scope_id_map = {
scope_list[0]['scope_id']: scope_list[0]
}
scope_expanding_hooks = []
scope_shrinking_hooks = []
@ -57,12 +54,14 @@ class ScopeManager(object):
ScopeLevelManager.expand_level()
current_scope = cls.current_scope()
next_id = ScopeIdManager.next_id()
next_scope = {
'scope_id': ScopeIdManager.next_id(),
'scope_id': next_id,
'scope_level': ScopeLevelManager.current_level(),
'scope_type': scope_type,
'tag_object': tag_object
}
cls.scope_id_map[next_id] = next_scope
cls.scope_list.append(next_scope)
for fn in cls.scope_expanding_hooks:
@ -81,6 +80,10 @@ class ScopeManager(object):
def current_scope(cls):
return cls.scope_list[-1]
@classmethod
def get_scope_info(cls, sid):
return cls.scope_id_map[sid]
@classmethod
def register_scope_expanding(cls, fn):
cls.scope_expanding_hooks.append(fn)

View File

@ -0,0 +1,47 @@
from py_hcl.core.stmt import LineStatement, BlockStatement
from .scope import ScopeManager, ScopeType
def set_up():
ScopeManager.register_scope_expanding(
StatementTrapper.on_scope_expanding)
ScopeManager.register_scope_shrinking(
StatementTrapper.on_scope_shrinking)
ScopeManager.expand_scope(ScopeType.GROUND)
class StatementTrapper(object):
trapped_stmts = [[
# TOP SCOPE which contains all modules
]]
@classmethod
def trap(cls):
ScopeManager.shrink_scope()
assert len(cls.trapped_stmts) == 1
ret = cls.trapped_stmts[0][-1]
ScopeManager.expand_scope(ScopeType.GROUND)
return ret # TODO
@classmethod
def track(cls, statement):
statement = LineStatement(
ScopeManager.current_scope()['scope_id'],
statement
)
cls.trapped_stmts[-1].append(statement)
@classmethod
def on_scope_expanding(cls, current_scope, next_scope):
cls.trapped_stmts.append([])
@classmethod
def on_scope_shrinking(cls, current_scope, next_scope):
stmts = cls.trapped_stmts.pop()
cls.trapped_stmts[-1].append(BlockStatement(current_scope, stmts))
set_up()

View File

@ -4,3 +4,7 @@ from py_hcl.utils import auto_repr
@auto_repr
class HclType(object):
pass
class UnknownType(HclType):
pass

15
py_hcl/core/type/sint.py Normal file
View File

@ -0,0 +1,15 @@
from py_hcl.core.type import HclType
from py_hcl.utils import signed_num_bin_len
class SIntT(HclType):
def __init__(self, width):
self.width = width
def __call__(self, value: int):
from py_hcl.core.expr.lit_sint import SLiteral
assert signed_num_bin_len(value) <= self.width
u = SLiteral(value)
u.hcl_type = SIntT(self.width)
return u

View File

@ -7,9 +7,9 @@ class UIntT(HclType):
self.width = width
def __call__(self, value: int):
from py_hcl.core.expr.uint_lit import ULiteral
from py_hcl.core.expr.lit_uint import ULiteral
assert unsigned_num_bin_len(value) <= self.width
u = ULiteral(value)
u.hcl_type = self
u.hcl_type = UIntT(self.width)
return u

View File

@ -1,8 +1,9 @@
import py_hcl.core.stmt.branch as cb
from py_hcl.core.expr import HclExpr
class when(object):
def __init__(self, cond_expr):
def __init__(self, cond_expr: HclExpr):
self.cond_expr = cond_expr
def __enter__(self):
@ -13,7 +14,7 @@ class when(object):
class else_when(object):
def __init__(self, cond_expr):
def __init__(self, cond_expr: HclExpr):
self.cond_expr = cond_expr
def __enter__(self):

View File

@ -1,5 +1,16 @@
import py_hcl.core.expr.io as cio
from typing import Union
IO = cio.IO
Input = cio.Input
Output = cio.Output
import py_hcl.core.expr.io as cio
from py_hcl.core.type import HclType
def IO(**named_ports: Union[cio.Input, cio.Output]) -> cio.IO:
return cio.IO(named_ports)
def Input(hcl_type: HclType) -> cio.Input:
return cio.Input(hcl_type)
def Output(hcl_type: HclType) -> cio.Output:
return cio.Output(hcl_type)

View File

@ -1,3 +1,6 @@
import py_hcl.core.expr.wire as cwr
from py_hcl.core.type import HclType
Wire = cwr.Wire
def Wire(hcl_type: HclType) -> cwr.Wire:
return cwr.Wire(hcl_type)

14
py_hcl/dsl/tpe/sint.py Normal file
View File

@ -0,0 +1,14 @@
from py_hcl.core.expr.lit_sint import SLiteral
from py_hcl.core.type.sint import SIntT
class _(object):
def __call__(self, value: int) -> SLiteral:
return SLiteral(value)
@staticmethod
def w(width: int) -> SIntT:
return SIntT(width)
S = _()

View File

@ -1,15 +1,15 @@
from py_hcl.core.type.uint import UIntT
from py_hcl.core.expr.uint_lit import ULiteral
from py_hcl.core.expr.lit_uint import ULiteral
class U(object):
def __call__(self, value: int):
class _(object):
def __call__(self, value: int) -> ULiteral:
return ULiteral(value)
@staticmethod
def w(width: int):
def w(width: int) -> UIntT:
return UIntT(width)
U = U()
U = _()
Bool = U.w(1)

View File

@ -33,7 +33,7 @@ class DefModule(Statement):
self.body = body
def serialize_stmt(self, output, indent):
output.write(b"module ")
output.write(b"module_factory ")
output.write(serialize_str(self.name))
output.write(b" :\n")
indent += 1

View File

@ -1,5 +1,5 @@
"""
The tpe module provides type nodes in FIRRTL IR.
The tpe module_factory provides type nodes in FIRRTL IR.
At the top level of types include UnknownType, GroundType, AggregateType.
GroundType acts as a primitive type, and AggregateType is similar to a

View File

@ -1,5 +1,5 @@
"""
The field module provides information about the field
The field module_factory provides information about the field
in BundleType.
Fields are allowed to be flipped, indicating that it's

View File

@ -1,5 +1,5 @@
"""
The width module provides bit width information for UIntType
The width module_factory provides bit width information for UIntType
and SIntType.
"""

View File

@ -1,3 +1,6 @@
from multipledispatch import dispatch
def signed_num_bin_len(num):
return len("{:+b}".format(num))
@ -8,10 +11,48 @@ def unsigned_num_bin_len(num):
def auto_repr(cls):
def __repr__(self):
return '%s(%s)' % (
type(self).__name__,
', '.join('%s = %s' % item for item in vars(self).items())
)
ls = ['{}={}'.format(k, _fm(v)) for k, v in vars(self).items()]
fs = _iter_repr(ls)
return '%s {%s}' % (type(self).__name__, ''.join(fs))
cls.__repr__ = __repr__
return cls
@dispatch()
def _fm(vd: dict):
ls = ['{}: {}'.format(k, _fm(v)) for k, v in vd.items()]
fs = _iter_repr(ls)
return '{%s}' % (''.join(fs))
@dispatch()
def _fm(v: list):
ls = [_fm(a) for a in v]
fs = _iter_repr(ls)
return '[%s]' % (''.join(fs))
@dispatch()
def _fm(v: tuple):
ls = [_fm(a) for a in v]
fs = _iter_repr(ls)
return '(%s)' % (''.join(fs))
@dispatch()
def _fm(v: object):
return str(v)
def _iter_repr(ls):
if len(ls) <= 1:
fs = ''.join(ls)
else:
fs = ''.join(['\n {},'.format(_indent(l)) for l in ls]) + '\n'
return fs
def _indent(s: str) -> str:
s = s.split('\n')
return '\n '.join(s)

View File

@ -1,9 +1,8 @@
import pytest
from py_hcl.core.expr import HclExpr
from py_hcl.core.stmt.connect import Connect
from py_hcl.core.stmt.error import StatementError
from py_hcl.core.stmt.scope import ScopeType
from py_hcl.core.stmt_factory.scope import ScopeType, ScopeManager
from py_hcl.dsl.branch import when, else_when, otherwise
from py_hcl.dsl.expr.io import IO
from py_hcl.dsl.expr.wire import Wire
@ -19,39 +18,47 @@ def test_branch():
c = Wire(U.w(8))
a <<= b
with when(HclExpr()):
with when(U(0)):
a <<= b + c
c <<= a
with else_when(HclExpr()):
with else_when(U(1)):
b <<= a + c
with when(HclExpr()):
with when(U(0)):
b <<= a
with otherwise():
c <<= a
with otherwise():
c <<= a + b
s = A.packed_module.top_scope.statements
s = A.packed_module.top_statement.statements
assert len(s) == 4
assert s[0]['scope']['scope_type'] == ScopeType.GROUND
assert isinstance(s[0]['statement'], Connect)
si = ScopeManager.get_scope_info(s[0].scope_id)
assert si['scope_type'] == ScopeType.GROUND
assert isinstance(s[0].statement, Connect)
assert s[1]['scope']['scope_type'] == ScopeType.WHEN
assert s[1]['scope']['scope_level'] == 2
assert len(s[1]['statement']) == 2
si = s[1].scope_info
assert si['scope_type'] == ScopeType.WHEN
assert si['scope_level'] == 2
assert len(s[1].statements) == 2
assert s[2]['scope']['scope_type'] == ScopeType.ELSE_WHEN
assert s[2]['scope']['scope_level'] == 2
assert len(s[2]['statement']) == 3
assert s[2]['statement'][1]['scope']['scope_type'] == ScopeType.WHEN
assert s[2]['statement'][1]['scope']['scope_level'] == 3
assert s[2]['statement'][2]['scope']['scope_type'] == ScopeType.OTHERWISE
assert s[2]['statement'][2]['scope']['scope_level'] == 3
si = s[2].scope_info
assert si['scope_type'] == ScopeType.ELSE_WHEN
assert si['scope_level'] == 2
assert len(s[2].statements) == 3
assert s[3]['scope']['scope_type'] == ScopeType.OTHERWISE
assert s[3]['scope']['scope_level'] == 2
assert len(s[3]['statement']) == 1
si = s[2].statements[1].scope_info
assert si['scope_type'] == ScopeType.WHEN
assert si['scope_level'] == 3
si = s[2].statements[2].scope_info
assert si['scope_type'] == ScopeType.OTHERWISE
assert si['scope_level'] == 3
si = s[3].scope_info
assert si['scope_type'] == ScopeType.OTHERWISE
assert si['scope_level'] == 2
assert len(s[3].statements) == 1
def test_branch_syntax_error1():
@ -63,7 +70,7 @@ def test_branch_syntax_error1():
c = Wire(U.w(8))
a <<= b
with else_when(HclExpr()):
with else_when(U(0)):
b <<= a + c
with otherwise():
c <<= a + b
@ -89,7 +96,7 @@ def test_branch_syntax_error3():
b = Wire(U.w(8))
c = Wire(U.w(8))
with when(HclExpr()):
with when(U(0)):
b <<= a + c
with otherwise():
c <<= a + b
@ -103,9 +110,9 @@ def test_branch_syntax_error4():
b = Wire(U.w(8))
c = Wire(U.w(8))
with when(HclExpr()):
with when(U(0)):
b <<= a + c
with else_when(HclExpr()):
with else_when(U(1)):
c <<= a + b
with otherwise():
c <<= a + b
@ -119,7 +126,7 @@ def test_branch_syntax_error5():
b = Wire(U.w(8))
c = Wire(U.w(8))
with when(HclExpr()):
with when(U(0)):
b <<= a + c
c <<= a + b
with otherwise():

View File

@ -27,13 +27,13 @@ def test_io_no_wrap_io():
io = IO(i=HclType())
with pytest.raises(ExprError, match='^.*Input.*Output.*$'):
class A(Module): # noqa: F811
class A(Module):
io = IO(
i=HclType(),
o=Output(HclType()))
with pytest.raises(ExprError, match='^.*Input.*Output.*$'):
class A(Module): # noqa: F811
class A(Module):
io = IO(
i=Input(HclType()),
o=HclType())

View File

@ -1,6 +1,6 @@
import pytest
from py_hcl.core.module.error import ModuleError
from py_hcl.core.module_factory.error import ModuleError
from py_hcl.core.expr import HclExpr
from py_hcl.dsl.expr.io import IO, Input
from py_hcl.dsl.module import Module

View File

@ -15,6 +15,6 @@ def test_statement():
c <<= a + b
s = A.packed_module.top_scope.statements
s = A.packed_module.top_statement.statements
assert len(s) == 1
assert isinstance(s[0]['statement'], Connect)
assert isinstance(s[0].statement, Connect)

View File

@ -24,12 +24,12 @@ def test_circuit_basis():
ct = DefCircuit("m1", [m1, m2])
assert check(ct)
serialize_stmt_equal(ct, 'circuit m1 :\n'
' module m1 :\n'
' module_factory m1 :\n'
' output p : UInt<8>\n'
'\n'
' p <= UInt<8>("2")\n'
'\n'
' module m2 :\n'
' module_factory m2 :\n'
' input b : UInt<8>\n'
' output a : UInt<8>\n'
'\n'

View File

@ -14,7 +14,7 @@ def test_module_basis():
mod = DefModule("m", [OutputPort("p", uw(8))],
Connect(n("p", uw(8)), u(2, w(8))))
assert check(mod)
serialize_stmt_equal(mod, 'module m :\n'
serialize_stmt_equal(mod, 'module_factory m :\n'
' output p : UInt<8>\n'
'\n'
' p <= UInt<8>("2")')
@ -27,7 +27,7 @@ def test_module_basis():
Connect(n("a", uw(8)), n("b", uw(8))))
]))
assert check(mod)
serialize_stmt_equal(mod, 'module m :\n'
serialize_stmt_equal(mod, 'module_factory m :\n'
' input b : UInt<8>\n'
' output a : UInt<8>\n'
'\n'