diff --git a/py_hcl/firrtl_ir/literal.py b/py_hcl/firrtl_ir/literal.py index fcaf9f7..1028c43 100644 --- a/py_hcl/firrtl_ir/literal.py +++ b/py_hcl/firrtl_ir/literal.py @@ -1,3 +1,4 @@ +from .tpe import UIntType, SIntType from .expression import Expression from .utils import serialize_str @@ -6,12 +7,13 @@ class UIntLiteral(Expression): def __init__(self, value, width): self.value = value self.width = width + self.tpe = UIntType(width) def serialize(self, output): output.write(b"UInt") self.width.serialize(output) output.write(b'("') - output.write(serialize_str(hex(self.value)[2:])) + output.write(serialize_str(hex(self.value).replace("0x", ""))) output.write(b'")') @@ -19,10 +21,11 @@ class SIntLiteral(Expression): def __init__(self, value, width): self.value = value self.width = width + self.tpe = SIntType(width) def serialize(self, output): output.write(b"SInt") self.width.serialize(output) output.write(b'("') - output.write(serialize_str(hex(self.value)[2:])) + output.write(serialize_str(hex(self.value).replace("0x", ""))) output.write(b'")') diff --git a/py_hcl/firrtl_ir/prim_call.py b/py_hcl/firrtl_ir/prim_call.py deleted file mode 100644 index 2234eb9..0000000 --- a/py_hcl/firrtl_ir/prim_call.py +++ /dev/null @@ -1,30 +0,0 @@ -from .utils import serialize_num, serialize_str -from .expression import Expression - - -class PrimCall(Expression): - def __init__(self, prim_op, ir_args, const_args, tpe): - self.prim_op = prim_op - self.ir_args = ir_args - self.const_args = const_args - self.tpe = tpe - - def serialize(self, output): - output.write(serialize_str(self.prim_op)) - output.write(b"(") - - comma_cnt = len(self.ir_args) + len(self.const_args) - 1 - - for ir_arg in self.ir_args: - ir_arg.serialize(output) - if comma_cnt > 0: - comma_cnt -= 1 - output.write(b", ") - - for const_arg in self.const_args: - output.write(serialize_num(const_arg)) - if comma_cnt > 0: - comma_cnt -= 1 - output.write(b", ") - - output.write(b")") diff --git a/py_hcl/firrtl_ir/prim_ops.py b/py_hcl/firrtl_ir/prim_ops.py index cf3967a..d5d10ce 100644 --- a/py_hcl/firrtl_ir/prim_ops.py +++ b/py_hcl/firrtl_ir/prim_ops.py @@ -1,25 +1,524 @@ -Add = 'add' -Sub = 'sub' -Lt = 'lt' -Leq = 'leq' -Gt = 'gt' -Geq = 'geq' -Eq = 'eq' -Neq = 'neq' -And = 'and' -Or = 'or' -Xor = 'xor' +from .expression import Expression +from .utils import serialize_num +from .tpe import UIntType, SIntType, ClockType -Mul = 'mul' -Div = 'div' -Rem = 'rem' -AsUInt = 'asUInt' -AsSInt = 'asSInt' -Shl = 'shl' -Shr = 'shr' -Dshl = 'dshl' -Dshr = 'dshr' -Neg = 'neg' -Not = 'not' -Cat = 'cat' -Bits = 'bits' + +def type_in(obj, *types): + for t in types: + if isinstance(obj, t): + return True + return False + + +def all_the_same(*objects): + t = objects[0] + for o in objects[1:]: + if o != t: + return False + return True + + +def check_all_same_uint_sint(*types): + for t in types: + if not type_in(t, UIntType, SIntType): + return False + + return all_the_same(*list(map(type, types))) + + +class Add(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe, + self.tpe): + return False + + expected_type_width = max(self.args[0].tpe.width.width, + self.args[1].tpe.width.width) + 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"add(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class Sub(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe, + self.tpe): + return False + + expected_type_width = max(self.args[0].tpe.width.width, + self.args[1].tpe.width.width) + 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"sub(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class Mul(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe, + self.tpe): + return False + + expected_type_width = \ + self.args[0].tpe.width.width + self.args[1].tpe.width.width + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"mul(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class Div(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe, + self.tpe): + return False + + expected_type_width = self.args[0].tpe.width.width + if type_in(self.args[0].tpe, SIntType): + expected_type_width += 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"div(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class Rem(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe, + self.tpe): + return False + + expected_type_width = min(self.args[0].tpe.width.width, + self.args[1].tpe.width.width) + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"rem(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class _Comparison(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe): + return False + + if not type_in(self.tpe, UIntType): + return False + + if self.tpe.width.width != 1: + return False + + return True + + def serialize(self, output): + output.write(b"(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class Lt(_Comparison): + def serialize(self, output): + output.write(b"lt") + super().serialize(output) + + +class Leq(_Comparison): + def serialize(self, output): + output.write(b"leq") + super().serialize(output) + + +class Gt(_Comparison): + def serialize(self, output): + output.write(b"gt") + super().serialize(output) + + +class Geq(_Comparison): + def serialize(self, output): + output.write(b"geq") + super().serialize(output) + + +class Eq(_Comparison): + def serialize(self, output): + output.write(b"eq") + super().serialize(output) + + +class Neq(_Comparison): + def serialize(self, output): + output.write(b"neq") + super().serialize(output) + + +class _BinaryBit(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe, + self.tpe): + return False + + expected_type_width = max(self.args[0].tpe.width.width, + self.args[1].tpe.width.width) + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class And(_BinaryBit): + def serialize(self, output): + output.write(b"and") + super().serialize(output) + + +class Or(_BinaryBit): + def serialize(self, output): + output.write(b"or") + super().serialize(output) + + +class Xor(_BinaryBit): + def serialize(self, output): + output.write(b"xor") + super().serialize(output) + + +class Not(Expression): + def __init__(self, arg, tpe): + self.arg = arg + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.arg.tpe, + self.tpe): + return False + + expected_type_width = self.arg.tpe.width.width + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"not(") + self.arg.serialize(output) + output.write(b")") + + +class Neg(Expression): + def __init__(self, arg, tpe): + self.arg = arg + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.arg.tpe, + self.tpe): + return False + + expected_type_width = self.arg.tpe.width.width + 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"neg(") + self.arg.serialize(output) + output.write(b")") + + +class Cat(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.args[1].tpe, + self.tpe): + return False + + expected_type_width = \ + self.args[0].tpe.width.width + self.args[1].tpe.width.width + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"cat(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class Bits(Expression): + def __init__(self, ir_arg, const_args, tpe): + self.ir_arg = ir_arg + self.const_args = const_args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.ir_arg.tpe): + return False + + if not type_in(self.tpe, UIntType): + return False + + if not \ + self.tpe.width.width >= \ + self.const_args[0] >= \ + self.const_args[1] >= 0: + return False + + expected_type_width = self.const_args[0] - self.const_args[1] + 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"bits(") + self.ir_arg.serialize(output) + output.write(b", ") + output.write(serialize_num(self.const_args[0])) + output.write(b", ") + output.write(serialize_num(self.const_args[1])) + output.write(b")") + + +class AsUInt(Expression): + def __init__(self, arg, tpe): + self.arg = arg + self.tpe = tpe + + def check_type(self): + if not type_in(self.arg.tpe, UIntType, SIntType, ClockType): + return False + + if not type_in(self.tpe, UIntType): + return False + + expected_type_width = self.arg.width.width + if type_in(self.tpe, ClockType): + expected_type_width = 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"asUInt(") + self.arg.serialize(output) + output.write(b")") + + +class AsSInt(Expression): + def __init__(self, arg, tpe): + self.arg = arg + self.tpe = tpe + + def check_type(self): + if not type_in(self.arg.tpe, UIntType, SIntType, ClockType): + return False + + if not type_in(self.tpe, SIntType): + return False + + expected_type_width = self.arg.width.width + if type_in(self.tpe, ClockType): + expected_type_width = 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"asSInt(") + self.arg.serialize(output) + output.write(b")") + + +class Shl(Expression): + def __init__(self, ir_arg, const_arg, tpe): + self.ir_arg = ir_arg + self.const_arg = const_arg + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.ir_arg.tpe, self.tpe): + return False + + expected_type_width = self.ir_arg.width.width + self.const_arg + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"shl(") + self.ir_arg.serialize(output) + output.write(b", ") + output.write(serialize_num(self.const_arg)) + output.write(b")") + + +class Shr(Expression): + def __init__(self, ir_arg, const_arg, tpe): + self.ir_arg = ir_arg + self.const_arg = const_arg + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.ir_arg.tpe, self.tpe): + return False + + expected_type_width = self.ir_arg.width.width - self.const_arg + expected_type_width = max(expected_type_width, 1) + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"shr(") + self.ir_arg.serialize(output) + output.write(b", ") + output.write(serialize_num(self.const_arg)) + output.write(b")") + + +class Dshl(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.tpe): + return False + + if type_in(self.args[1].tpe, UIntType): + return False + + expected_type_width = \ + self.args[0].width.width + 2 ** self.args[1].width.width - 1 + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"dshl(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") + + +class Dshr(Expression): + def __init__(self, args, tpe): + self.args = args + self.tpe = tpe + + def check_type(self): + if not check_all_same_uint_sint(self.args[0].tpe, + self.tpe): + return False + + if type_in(self.args[1].tpe, UIntType): + return False + + expected_type_width = self.args[0].width.width + if self.tpe.width.width != expected_type_width: + return False + + return True + + def serialize(self, output): + output.write(b"dshr(") + self.args[0].serialize(output) + output.write(b", ") + self.args[1].serialize(output) + output.write(b")") diff --git a/py_hcl/firrtl_ir/shortcuts.py b/py_hcl/firrtl_ir/shortcuts.py new file mode 100644 index 0000000..a26fed1 --- /dev/null +++ b/py_hcl/firrtl_ir/shortcuts.py @@ -0,0 +1,38 @@ +from .field import Field +from .reference import Reference +from .literal import UIntLiteral, SIntLiteral +from .tpe import SIntType, UIntType, VectorType, BundleType +from .width import IntWidth + + +def sw(width): + return SIntType(IntWidth(width)) + + +def uw(width): + return UIntType(IntWidth(width)) + + +def w(width): + return IntWidth(width) + + +def u(value, width): + return UIntLiteral(value, width) + + +def s(value, width): + return SIntLiteral(value, width) + + +def n(name, tpe): + return Reference(name, tpe) + + +def vec(tpe, size): + return VectorType(tpe, size) + + +def bdl(**field): + fields = [Field(k, field[k][0], field[k][1]) for k in field] + return BundleType(fields) diff --git a/py_hcl/firrtl_ir/utils.py b/py_hcl/firrtl_ir/utils.py index 0abe1f5..3e2f9a4 100644 --- a/py_hcl/firrtl_ir/utils.py +++ b/py_hcl/firrtl_ir/utils.py @@ -4,3 +4,7 @@ def serialize_num(num): def serialize_str(s): return bytes(s, 'utf-8') + + +def signed_num_bin_len(num): + return len(bin(abs(num))) - 1 diff --git a/tests/test_firrtl_ir/test_prim_op.py b/tests/test_firrtl_ir/test_prim_op.py deleted file mode 100644 index bff8582..0000000 --- a/tests/test_firrtl_ir/test_prim_op.py +++ /dev/null @@ -1,36 +0,0 @@ -from py_hcl.firrtl_ir.literal import UIntLiteral -from py_hcl.firrtl_ir.prim_call import PrimCall -from py_hcl.firrtl_ir.prim_ops import Add, Sub, Lt, Leq, \ - Gt, Geq, Eq, Neq, And, Or, Xor -from py_hcl.firrtl_ir.reference import Reference -from py_hcl.firrtl_ir.tpe import UIntType -from py_hcl.firrtl_ir.width import IntWidth -from .utils import serialize_equal - - -def test_prim_op_add_sub_lt_leq_gt_geq_eq_neq_and_or_xor(): - tpe = UIntType(IntWidth(8)) - ops = [ - Add, - Sub, - Lt, - Leq, - Gt, - Geq, - Eq, - Neq, - And, - Or, - Xor, - ] - cases = [ - ([Reference("a", tpe), Reference("b", tpe)], tpe, - lambda op: op + '(a, b)'), - ([UIntLiteral(2, IntWidth(8)), UIntLiteral(4, IntWidth(8))], tpe, - lambda op: op + '(UInt<8>("2"), UInt<8>("4"))'), - ([UIntLiteral(2, IntWidth(8)), Reference("a", tpe)], tpe, - lambda op: op + '(UInt<8>("2"), a)'), - ] - for case in cases: - for o in ops: - serialize_equal(PrimCall(o, case[0], [], case[1]), case[2](o)) diff --git a/py_hcl/firrtl_ir/prim_opps.py b/tests/test_firrtl_ir/test_prim_ops/__init__.py similarity index 100% rename from py_hcl/firrtl_ir/prim_opps.py rename to tests/test_firrtl_ir/test_prim_ops/__init__.py diff --git a/tests/test_firrtl_ir/test_prim_ops/test_add.py b/tests/test_firrtl_ir/test_prim_ops/test_add.py new file mode 100644 index 0000000..1268fbc --- /dev/null +++ b/tests/test_firrtl_ir/test_prim_ops/test_add.py @@ -0,0 +1,72 @@ +from py_hcl.firrtl_ir.prim_ops import Add +from py_hcl.firrtl_ir.shortcuts import w, uw, u, s, sw, n, vec, bdl +from py_hcl.firrtl_ir.tpe import UnknownType +from ..utils import serialize_equal + + +def test_basis(): + args = [u(20, w(5)), u(15, w(4))] + add = Add(args, uw(6)) + assert add.check_type() + serialize_equal(add, 'add(UInt<5>("14"), UInt<4>("f"))') + + args = [n("a", uw(6)), u(15, w(4))] + add = Add(args, uw(7)) + assert add.check_type() + serialize_equal(add, 'add(a, UInt<4>("f"))') + + args = [n("a", uw(6)), n("b", uw(6))] + add = Add(args, uw(7)) + assert add.check_type() + serialize_equal(add, 'add(a, b)') + + args = [s(20, w(6)), s(15, w(5))] + add = Add(args, sw(7)) + assert add.check_type() + serialize_equal(add, 'add(SInt<6>("14"), SInt<5>("f"))') + + args = [n("a", sw(6)), s(-15, w(5))] + add = Add(args, sw(7)) + assert add.check_type() + serialize_equal(add, 'add(a, SInt<5>("-f"))') + + args = [n("a", sw(6)), n("b", sw(6))] + add = Add(args, sw(7)) + assert add.check_type() + serialize_equal(add, 'add(a, b)') + + +def test_type_is_wrong(): + args = [n("a", UnknownType()), n("b", sw(6))] + add = Add(args, sw(7)) + assert not add.check_type() + + args = [n("a", uw(6)), n("b", UnknownType())] + add = Add(args, sw(7)) + assert not add.check_type() + + args = [n("a", vec(sw(10), 8)), n("b", sw(6))] + add = Add(args, sw(7)) + assert not add.check_type() + + args = [n("a", uw(6)), n("b", bdl(a=[uw(20), True]))] + add = Add(args, sw(7)) + assert not add.check_type() + + +def test_width_is_wrong(): + args = [n("a", sw(6)), n("b", sw(6))] + add = Add(args, sw(6)) + assert not add.check_type() + + args = [n("a", uw(6)), n("b", uw(6))] + add = Add(args, uw(6)) + assert not add.check_type() + + args = [n("a", sw(6)), n("b", sw(6))] + add = Add(args, sw(1)) + assert not add.check_type() + + args = [n("a", uw(6)), n("b", uw(6))] + add = Add(args, uw(10)) + assert not add.check_type() diff --git a/tests/test_firrtl_ir/test_type.py b/tests/test_firrtl_ir/test_type.py index dc6656f..daa10cc 100644 --- a/tests/test_firrtl_ir/test_type.py +++ b/tests/test_firrtl_ir/test_type.py @@ -70,330 +70,3 @@ def test_bundle_type(): ])) ]) serialize_equal(bd, "{l1 : {l2 : {flip l3 : UInt<8>}, vt : UInt<8>[16]}}") - - -def test_type_eq(): - assert UnknownType().type_eq( - UnknownType() - ) - assert ClockType().type_eq( - ClockType() - ) - assert UIntType(IntWidth(10)).type_eq( - UIntType(IntWidth(10)) - ) - assert UIntType(UnknownWidth()).type_eq( - UIntType(UnknownWidth()) - ) - assert SIntType(IntWidth(10)).type_eq( - SIntType(IntWidth(10)) - ) - assert SIntType(UnknownWidth()).type_eq( - SIntType(UnknownWidth()) - ) - assert VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(UIntType(IntWidth(10)), 8) - ) - assert BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - -def test_type_neq(): - assert not UnknownType().type_eq( - ClockType() - ) - assert not UnknownType().type_eq(UIntType( - IntWidth(10)) - ) - assert not UnknownType().type_eq(UIntType( - UnknownWidth()) - ) - assert not UnknownType().type_eq(SIntType( - IntWidth(10)) - ) - assert not UnknownType().type_eq(SIntType( - UnknownWidth()) - ) - assert not UnknownType().type_eq(VectorType( - UIntType(IntWidth(10)), 8) - ) - assert not UnknownType().type_eq(VectorType( - UIntType(UnknownWidth()), 8) - ) - assert not UnknownType().type_eq(VectorType( - SIntType(IntWidth(10)), 8) - ) - assert not UnknownType().type_eq(VectorType( - SIntType(UnknownWidth()), 8) - ) - assert not UnknownType().type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - assert not ClockType().type_eq( - UIntType(IntWidth(10)) - ) - assert not ClockType().type_eq( - UIntType(UnknownWidth()) - ) - assert not ClockType().type_eq( - SIntType(IntWidth(10)) - ) - assert not ClockType().type_eq( - SIntType(UnknownWidth()) - ) - assert not ClockType().type_eq( - VectorType(UIntType(IntWidth(10)), 8) - - ) - assert not ClockType().type_eq( - VectorType(UIntType(UnknownWidth()), 8) - ) - assert not ClockType().type_eq( - VectorType(SIntType(IntWidth(10)), 8) - ) - assert not ClockType().type_eq( - VectorType(SIntType(UnknownWidth()), 8) - ) - assert not ClockType().type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - assert not UIntType(IntWidth(10)).type_eq( - UIntType(UnknownWidth()) - ) - assert not UIntType(UnknownWidth()).type_eq( - UIntType(IntWidth(10)) - ) - assert not UIntType(IntWidth(10)).type_eq( - UIntType(IntWidth(8)) - ) - assert not UIntType(IntWidth(10)).type_eq( - SIntType(IntWidth(10)) - ) - assert not UIntType(IntWidth(10)).type_eq( - SIntType(UnknownWidth()) - ) - assert not UIntType(IntWidth(10)).type_eq( - VectorType(UIntType(IntWidth(10)), 8) - ) - assert not UIntType(IntWidth(10)).type_eq( - VectorType(UIntType(UnknownWidth()), 8) - ) - assert not UIntType(IntWidth(10)).type_eq( - VectorType(SIntType(IntWidth(10)), 8) - ) - assert not UIntType(IntWidth(10)).type_eq( - VectorType(SIntType(UnknownWidth()), 8) - ) - assert not UIntType(IntWidth(10)).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - assert not UIntType(UnknownWidth()).type_eq( - SIntType(IntWidth(10)) - ) - assert not UIntType(UnknownWidth()).type_eq( - SIntType(UnknownWidth()) - ) - assert not UIntType(UnknownWidth()).type_eq( - VectorType(UIntType(IntWidth(10)), 8) - ) - assert not UIntType(UnknownWidth()).type_eq( - VectorType(UIntType(UnknownWidth()), 8) - ) - assert not UIntType(UnknownWidth()).type_eq( - VectorType(SIntType(IntWidth(10)), 8) - ) - assert not UIntType(UnknownWidth()).type_eq( - VectorType(SIntType(UnknownWidth()), 8) - ) - assert not UIntType(UnknownWidth()).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - assert not SIntType(IntWidth(10)).type_eq( - SIntType(UnknownWidth()) - ) - assert not SIntType(IntWidth(10)).type_eq( - SIntType(IntWidth(8)) - ) - assert not SIntType(IntWidth(10)).type_eq( - VectorType(UIntType(IntWidth(10)), 8) - ) - assert not SIntType(IntWidth(10)).type_eq( - VectorType(UIntType(UnknownWidth()), 8) - - ) - assert not SIntType(IntWidth(10)).type_eq( - VectorType(SIntType(IntWidth(10)), 8) - ) - assert not SIntType(IntWidth(10)).type_eq( - VectorType(SIntType(UnknownWidth()), 8) - ) - assert not SIntType(IntWidth(10)).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - assert not SIntType(UnknownWidth()).type_eq( - VectorType(UIntType(IntWidth(10)), 8) - ) - assert not SIntType(UnknownWidth()).type_eq( - VectorType(UIntType(UnknownWidth()), 8) - ) - assert not SIntType(UnknownWidth()).type_eq( - VectorType(SIntType(IntWidth(10)), 8) - ) - assert not SIntType(UnknownWidth()).type_eq( - VectorType(SIntType(UnknownWidth()), 8) - ) - assert not SIntType(UnknownWidth()).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(UIntType(IntWidth(8)), 8) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(UIntType(IntWidth(10)), 6) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(UIntType(UnknownWidth()), 8) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(SIntType(IntWidth(10)), 8) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(SIntType(IntWidth(8)), 8) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(SIntType(IntWidth(10)), 6) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(SIntType(UnknownWidth()), 8) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(UnknownType(), 8) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - VectorType(ClockType(), 8) - ) - assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("c", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("b", VectorType(UIntType(IntWidth(10)), 8)), - Field("a", UIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(8)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 6)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("a", VectorType(UIntType(UnknownWidth()), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("a", VectorType(SIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("a", UnknownType()), - Field("b", UIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("b", UnknownType()), - Field("a", SIntType(IntWidth(10))) - ]) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - UIntType(IntWidth(10)) - ) - assert not BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8)), - Field("b", UIntType(IntWidth(10))) - ]).type_eq( - BundleType([ - Field("a", VectorType(UIntType(IntWidth(10)), 8), True), - Field("b", UIntType(IntWidth(10)), True) - ]) - ) diff --git a/tests/test_firrtl_ir/test_type_equal.py b/tests/test_firrtl_ir/test_type_equal.py new file mode 100644 index 0000000..bacf544 --- /dev/null +++ b/tests/test_firrtl_ir/test_type_equal.py @@ -0,0 +1,331 @@ +from py_hcl.firrtl_ir.field import Field +from py_hcl.firrtl_ir.tpe import UnknownType, ClockType, \ + UIntType, SIntType, VectorType, BundleType +from py_hcl.firrtl_ir.width import IntWidth, UnknownWidth + + +def test_type_eq(): + assert UnknownType().type_eq( + UnknownType() + ) + assert ClockType().type_eq( + ClockType() + ) + assert UIntType(IntWidth(10)).type_eq( + UIntType(IntWidth(10)) + ) + assert UIntType(UnknownWidth()).type_eq( + UIntType(UnknownWidth()) + ) + assert SIntType(IntWidth(10)).type_eq( + SIntType(IntWidth(10)) + ) + assert SIntType(UnknownWidth()).type_eq( + SIntType(UnknownWidth()) + ) + assert VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(UIntType(IntWidth(10)), 8) + ) + assert BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + +def test_type_neq(): + assert not UnknownType().type_eq( + ClockType() + ) + assert not UnknownType().type_eq(UIntType( + IntWidth(10)) + ) + assert not UnknownType().type_eq(UIntType( + UnknownWidth()) + ) + assert not UnknownType().type_eq(SIntType( + IntWidth(10)) + ) + assert not UnknownType().type_eq(SIntType( + UnknownWidth()) + ) + assert not UnknownType().type_eq(VectorType( + UIntType(IntWidth(10)), 8) + ) + assert not UnknownType().type_eq(VectorType( + UIntType(UnknownWidth()), 8) + ) + assert not UnknownType().type_eq(VectorType( + SIntType(IntWidth(10)), 8) + ) + assert not UnknownType().type_eq(VectorType( + SIntType(UnknownWidth()), 8) + ) + assert not UnknownType().type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + assert not ClockType().type_eq( + UIntType(IntWidth(10)) + ) + assert not ClockType().type_eq( + UIntType(UnknownWidth()) + ) + assert not ClockType().type_eq( + SIntType(IntWidth(10)) + ) + assert not ClockType().type_eq( + SIntType(UnknownWidth()) + ) + assert not ClockType().type_eq( + VectorType(UIntType(IntWidth(10)), 8) + + ) + assert not ClockType().type_eq( + VectorType(UIntType(UnknownWidth()), 8) + ) + assert not ClockType().type_eq( + VectorType(SIntType(IntWidth(10)), 8) + ) + assert not ClockType().type_eq( + VectorType(SIntType(UnknownWidth()), 8) + ) + assert not ClockType().type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + assert not UIntType(IntWidth(10)).type_eq( + UIntType(UnknownWidth()) + ) + assert not UIntType(UnknownWidth()).type_eq( + UIntType(IntWidth(10)) + ) + assert not UIntType(IntWidth(10)).type_eq( + UIntType(IntWidth(8)) + ) + assert not UIntType(IntWidth(10)).type_eq( + SIntType(IntWidth(10)) + ) + assert not UIntType(IntWidth(10)).type_eq( + SIntType(UnknownWidth()) + ) + assert not UIntType(IntWidth(10)).type_eq( + VectorType(UIntType(IntWidth(10)), 8) + ) + assert not UIntType(IntWidth(10)).type_eq( + VectorType(UIntType(UnknownWidth()), 8) + ) + assert not UIntType(IntWidth(10)).type_eq( + VectorType(SIntType(IntWidth(10)), 8) + ) + assert not UIntType(IntWidth(10)).type_eq( + VectorType(SIntType(UnknownWidth()), 8) + ) + assert not UIntType(IntWidth(10)).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + assert not UIntType(UnknownWidth()).type_eq( + SIntType(IntWidth(10)) + ) + assert not UIntType(UnknownWidth()).type_eq( + SIntType(UnknownWidth()) + ) + assert not UIntType(UnknownWidth()).type_eq( + VectorType(UIntType(IntWidth(10)), 8) + ) + assert not UIntType(UnknownWidth()).type_eq( + VectorType(UIntType(UnknownWidth()), 8) + ) + assert not UIntType(UnknownWidth()).type_eq( + VectorType(SIntType(IntWidth(10)), 8) + ) + assert not UIntType(UnknownWidth()).type_eq( + VectorType(SIntType(UnknownWidth()), 8) + ) + assert not UIntType(UnknownWidth()).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + assert not SIntType(IntWidth(10)).type_eq( + SIntType(UnknownWidth()) + ) + assert not SIntType(IntWidth(10)).type_eq( + SIntType(IntWidth(8)) + ) + assert not SIntType(IntWidth(10)).type_eq( + VectorType(UIntType(IntWidth(10)), 8) + ) + assert not SIntType(IntWidth(10)).type_eq( + VectorType(UIntType(UnknownWidth()), 8) + + ) + assert not SIntType(IntWidth(10)).type_eq( + VectorType(SIntType(IntWidth(10)), 8) + ) + assert not SIntType(IntWidth(10)).type_eq( + VectorType(SIntType(UnknownWidth()), 8) + ) + assert not SIntType(IntWidth(10)).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + assert not SIntType(UnknownWidth()).type_eq( + VectorType(UIntType(IntWidth(10)), 8) + ) + assert not SIntType(UnknownWidth()).type_eq( + VectorType(UIntType(UnknownWidth()), 8) + ) + assert not SIntType(UnknownWidth()).type_eq( + VectorType(SIntType(IntWidth(10)), 8) + ) + assert not SIntType(UnknownWidth()).type_eq( + VectorType(SIntType(UnknownWidth()), 8) + ) + assert not SIntType(UnknownWidth()).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(UIntType(IntWidth(8)), 8) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(UIntType(IntWidth(10)), 6) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(UIntType(UnknownWidth()), 8) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(SIntType(IntWidth(10)), 8) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(SIntType(IntWidth(8)), 8) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(SIntType(IntWidth(10)), 6) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(SIntType(UnknownWidth()), 8) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(UnknownType(), 8) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + VectorType(ClockType(), 8) + ) + assert not VectorType(UIntType(IntWidth(10)), 8).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("c", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("b", VectorType(UIntType(IntWidth(10)), 8)), + Field("a", UIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(8)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 6)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("a", VectorType(UIntType(UnknownWidth()), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("a", VectorType(SIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("a", UnknownType()), + Field("b", UIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("b", UnknownType()), + Field("a", SIntType(IntWidth(10))) + ]) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + UIntType(IntWidth(10)) + ) + assert not BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8)), + Field("b", UIntType(IntWidth(10))) + ]).type_eq( + BundleType([ + Field("a", VectorType(UIntType(IntWidth(10)), 8), True), + Field("b", UIntType(IntWidth(10)), True) + ]) + )