firrtl_ir: add some ops (#9)

This commit is contained in:
Gaufoo 2019-11-15 22:54:21 +08:00 committed by GitHub
parent 6e48d60b63
commit 9df6f0e851
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 973 additions and 419 deletions

View File

@ -1,3 +1,4 @@
from .tpe import UIntType, SIntType
from .expression import Expression
from .utils import serialize_str
@ -6,12 +7,13 @@ class UIntLiteral(Expression):
def __init__(self, value, width):
self.value = value
self.width = width
self.tpe = UIntType(width)
def serialize(self, output):
output.write(b"UInt")
self.width.serialize(output)
output.write(b'("')
output.write(serialize_str(hex(self.value)[2:]))
output.write(serialize_str(hex(self.value).replace("0x", "")))
output.write(b'")')
@ -19,10 +21,11 @@ class SIntLiteral(Expression):
def __init__(self, value, width):
self.value = value
self.width = width
self.tpe = SIntType(width)
def serialize(self, output):
output.write(b"SInt")
self.width.serialize(output)
output.write(b'("')
output.write(serialize_str(hex(self.value)[2:]))
output.write(serialize_str(hex(self.value).replace("0x", "")))
output.write(b'")')

View File

@ -1,30 +0,0 @@
from .utils import serialize_num, serialize_str
from .expression import Expression
class PrimCall(Expression):
def __init__(self, prim_op, ir_args, const_args, tpe):
self.prim_op = prim_op
self.ir_args = ir_args
self.const_args = const_args
self.tpe = tpe
def serialize(self, output):
output.write(serialize_str(self.prim_op))
output.write(b"(")
comma_cnt = len(self.ir_args) + len(self.const_args) - 1
for ir_arg in self.ir_args:
ir_arg.serialize(output)
if comma_cnt > 0:
comma_cnt -= 1
output.write(b", ")
for const_arg in self.const_args:
output.write(serialize_num(const_arg))
if comma_cnt > 0:
comma_cnt -= 1
output.write(b", ")
output.write(b")")

View File

@ -1,25 +1,524 @@
Add = 'add'
Sub = 'sub'
Lt = 'lt'
Leq = 'leq'
Gt = 'gt'
Geq = 'geq'
Eq = 'eq'
Neq = 'neq'
And = 'and'
Or = 'or'
Xor = 'xor'
from .expression import Expression
from .utils import serialize_num
from .tpe import UIntType, SIntType, ClockType
Mul = 'mul'
Div = 'div'
Rem = 'rem'
AsUInt = 'asUInt'
AsSInt = 'asSInt'
Shl = 'shl'
Shr = 'shr'
Dshl = 'dshl'
Dshr = 'dshr'
Neg = 'neg'
Not = 'not'
Cat = 'cat'
Bits = 'bits'
def type_in(obj, *types):
for t in types:
if isinstance(obj, t):
return True
return False
def all_the_same(*objects):
t = objects[0]
for o in objects[1:]:
if o != t:
return False
return True
def check_all_same_uint_sint(*types):
for t in types:
if not type_in(t, UIntType, SIntType):
return False
return all_the_same(*list(map(type, types)))
class Add(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe,
self.tpe):
return False
expected_type_width = max(self.args[0].tpe.width.width,
self.args[1].tpe.width.width) + 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"add(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class Sub(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe,
self.tpe):
return False
expected_type_width = max(self.args[0].tpe.width.width,
self.args[1].tpe.width.width) + 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"sub(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class Mul(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe,
self.tpe):
return False
expected_type_width = \
self.args[0].tpe.width.width + self.args[1].tpe.width.width
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"mul(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class Div(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe,
self.tpe):
return False
expected_type_width = self.args[0].tpe.width.width
if type_in(self.args[0].tpe, SIntType):
expected_type_width += 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"div(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class Rem(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe,
self.tpe):
return False
expected_type_width = min(self.args[0].tpe.width.width,
self.args[1].tpe.width.width)
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"rem(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class _Comparison(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe):
return False
if not type_in(self.tpe, UIntType):
return False
if self.tpe.width.width != 1:
return False
return True
def serialize(self, output):
output.write(b"(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class Lt(_Comparison):
def serialize(self, output):
output.write(b"lt")
super().serialize(output)
class Leq(_Comparison):
def serialize(self, output):
output.write(b"leq")
super().serialize(output)
class Gt(_Comparison):
def serialize(self, output):
output.write(b"gt")
super().serialize(output)
class Geq(_Comparison):
def serialize(self, output):
output.write(b"geq")
super().serialize(output)
class Eq(_Comparison):
def serialize(self, output):
output.write(b"eq")
super().serialize(output)
class Neq(_Comparison):
def serialize(self, output):
output.write(b"neq")
super().serialize(output)
class _BinaryBit(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe,
self.tpe):
return False
expected_type_width = max(self.args[0].tpe.width.width,
self.args[1].tpe.width.width)
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class And(_BinaryBit):
def serialize(self, output):
output.write(b"and")
super().serialize(output)
class Or(_BinaryBit):
def serialize(self, output):
output.write(b"or")
super().serialize(output)
class Xor(_BinaryBit):
def serialize(self, output):
output.write(b"xor")
super().serialize(output)
class Not(Expression):
def __init__(self, arg, tpe):
self.arg = arg
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.arg.tpe,
self.tpe):
return False
expected_type_width = self.arg.tpe.width.width
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"not(")
self.arg.serialize(output)
output.write(b")")
class Neg(Expression):
def __init__(self, arg, tpe):
self.arg = arg
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.arg.tpe,
self.tpe):
return False
expected_type_width = self.arg.tpe.width.width + 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"neg(")
self.arg.serialize(output)
output.write(b")")
class Cat(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.args[1].tpe,
self.tpe):
return False
expected_type_width = \
self.args[0].tpe.width.width + self.args[1].tpe.width.width
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"cat(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class Bits(Expression):
def __init__(self, ir_arg, const_args, tpe):
self.ir_arg = ir_arg
self.const_args = const_args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.ir_arg.tpe):
return False
if not type_in(self.tpe, UIntType):
return False
if not \
self.tpe.width.width >= \
self.const_args[0] >= \
self.const_args[1] >= 0:
return False
expected_type_width = self.const_args[0] - self.const_args[1] + 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"bits(")
self.ir_arg.serialize(output)
output.write(b", ")
output.write(serialize_num(self.const_args[0]))
output.write(b", ")
output.write(serialize_num(self.const_args[1]))
output.write(b")")
class AsUInt(Expression):
def __init__(self, arg, tpe):
self.arg = arg
self.tpe = tpe
def check_type(self):
if not type_in(self.arg.tpe, UIntType, SIntType, ClockType):
return False
if not type_in(self.tpe, UIntType):
return False
expected_type_width = self.arg.width.width
if type_in(self.tpe, ClockType):
expected_type_width = 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"asUInt(")
self.arg.serialize(output)
output.write(b")")
class AsSInt(Expression):
def __init__(self, arg, tpe):
self.arg = arg
self.tpe = tpe
def check_type(self):
if not type_in(self.arg.tpe, UIntType, SIntType, ClockType):
return False
if not type_in(self.tpe, SIntType):
return False
expected_type_width = self.arg.width.width
if type_in(self.tpe, ClockType):
expected_type_width = 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"asSInt(")
self.arg.serialize(output)
output.write(b")")
class Shl(Expression):
def __init__(self, ir_arg, const_arg, tpe):
self.ir_arg = ir_arg
self.const_arg = const_arg
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.ir_arg.tpe, self.tpe):
return False
expected_type_width = self.ir_arg.width.width + self.const_arg
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"shl(")
self.ir_arg.serialize(output)
output.write(b", ")
output.write(serialize_num(self.const_arg))
output.write(b")")
class Shr(Expression):
def __init__(self, ir_arg, const_arg, tpe):
self.ir_arg = ir_arg
self.const_arg = const_arg
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.ir_arg.tpe, self.tpe):
return False
expected_type_width = self.ir_arg.width.width - self.const_arg
expected_type_width = max(expected_type_width, 1)
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"shr(")
self.ir_arg.serialize(output)
output.write(b", ")
output.write(serialize_num(self.const_arg))
output.write(b")")
class Dshl(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.tpe):
return False
if type_in(self.args[1].tpe, UIntType):
return False
expected_type_width = \
self.args[0].width.width + 2 ** self.args[1].width.width - 1
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"dshl(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")
class Dshr(Expression):
def __init__(self, args, tpe):
self.args = args
self.tpe = tpe
def check_type(self):
if not check_all_same_uint_sint(self.args[0].tpe,
self.tpe):
return False
if type_in(self.args[1].tpe, UIntType):
return False
expected_type_width = self.args[0].width.width
if self.tpe.width.width != expected_type_width:
return False
return True
def serialize(self, output):
output.write(b"dshr(")
self.args[0].serialize(output)
output.write(b", ")
self.args[1].serialize(output)
output.write(b")")

View File

@ -0,0 +1,38 @@
from .field import Field
from .reference import Reference
from .literal import UIntLiteral, SIntLiteral
from .tpe import SIntType, UIntType, VectorType, BundleType
from .width import IntWidth
def sw(width):
return SIntType(IntWidth(width))
def uw(width):
return UIntType(IntWidth(width))
def w(width):
return IntWidth(width)
def u(value, width):
return UIntLiteral(value, width)
def s(value, width):
return SIntLiteral(value, width)
def n(name, tpe):
return Reference(name, tpe)
def vec(tpe, size):
return VectorType(tpe, size)
def bdl(**field):
fields = [Field(k, field[k][0], field[k][1]) for k in field]
return BundleType(fields)

View File

@ -4,3 +4,7 @@ def serialize_num(num):
def serialize_str(s):
return bytes(s, 'utf-8')
def signed_num_bin_len(num):
return len(bin(abs(num))) - 1

View File

@ -1,36 +0,0 @@
from py_hcl.firrtl_ir.literal import UIntLiteral
from py_hcl.firrtl_ir.prim_call import PrimCall
from py_hcl.firrtl_ir.prim_ops import Add, Sub, Lt, Leq, \
Gt, Geq, Eq, Neq, And, Or, Xor
from py_hcl.firrtl_ir.reference import Reference
from py_hcl.firrtl_ir.tpe import UIntType
from py_hcl.firrtl_ir.width import IntWidth
from .utils import serialize_equal
def test_prim_op_add_sub_lt_leq_gt_geq_eq_neq_and_or_xor():
tpe = UIntType(IntWidth(8))
ops = [
Add,
Sub,
Lt,
Leq,
Gt,
Geq,
Eq,
Neq,
And,
Or,
Xor,
]
cases = [
([Reference("a", tpe), Reference("b", tpe)], tpe,
lambda op: op + '(a, b)'),
([UIntLiteral(2, IntWidth(8)), UIntLiteral(4, IntWidth(8))], tpe,
lambda op: op + '(UInt<8>("2"), UInt<8>("4"))'),
([UIntLiteral(2, IntWidth(8)), Reference("a", tpe)], tpe,
lambda op: op + '(UInt<8>("2"), a)'),
]
for case in cases:
for o in ops:
serialize_equal(PrimCall(o, case[0], [], case[1]), case[2](o))

View File

@ -0,0 +1,72 @@
from py_hcl.firrtl_ir.prim_ops import Add
from py_hcl.firrtl_ir.shortcuts import w, uw, u, s, sw, n, vec, bdl
from py_hcl.firrtl_ir.tpe import UnknownType
from ..utils import serialize_equal
def test_basis():
args = [u(20, w(5)), u(15, w(4))]
add = Add(args, uw(6))
assert add.check_type()
serialize_equal(add, 'add(UInt<5>("14"), UInt<4>("f"))')
args = [n("a", uw(6)), u(15, w(4))]
add = Add(args, uw(7))
assert add.check_type()
serialize_equal(add, 'add(a, UInt<4>("f"))')
args = [n("a", uw(6)), n("b", uw(6))]
add = Add(args, uw(7))
assert add.check_type()
serialize_equal(add, 'add(a, b)')
args = [s(20, w(6)), s(15, w(5))]
add = Add(args, sw(7))
assert add.check_type()
serialize_equal(add, 'add(SInt<6>("14"), SInt<5>("f"))')
args = [n("a", sw(6)), s(-15, w(5))]
add = Add(args, sw(7))
assert add.check_type()
serialize_equal(add, 'add(a, SInt<5>("-f"))')
args = [n("a", sw(6)), n("b", sw(6))]
add = Add(args, sw(7))
assert add.check_type()
serialize_equal(add, 'add(a, b)')
def test_type_is_wrong():
args = [n("a", UnknownType()), n("b", sw(6))]
add = Add(args, sw(7))
assert not add.check_type()
args = [n("a", uw(6)), n("b", UnknownType())]
add = Add(args, sw(7))
assert not add.check_type()
args = [n("a", vec(sw(10), 8)), n("b", sw(6))]
add = Add(args, sw(7))
assert not add.check_type()
args = [n("a", uw(6)), n("b", bdl(a=[uw(20), True]))]
add = Add(args, sw(7))
assert not add.check_type()
def test_width_is_wrong():
args = [n("a", sw(6)), n("b", sw(6))]
add = Add(args, sw(6))
assert not add.check_type()
args = [n("a", uw(6)), n("b", uw(6))]
add = Add(args, uw(6))
assert not add.check_type()
args = [n("a", sw(6)), n("b", sw(6))]
add = Add(args, sw(1))
assert not add.check_type()
args = [n("a", uw(6)), n("b", uw(6))]
add = Add(args, uw(10))
assert not add.check_type()

View File

@ -70,330 +70,3 @@ def test_bundle_type():
]))
])
serialize_equal(bd, "{l1 : {l2 : {flip l3 : UInt<8>}, vt : UInt<8>[16]}}")
def test_type_eq():
assert UnknownType().type_eq(
UnknownType()
)
assert ClockType().type_eq(
ClockType()
)
assert UIntType(IntWidth(10)).type_eq(
UIntType(IntWidth(10))
)
assert UIntType(UnknownWidth()).type_eq(
UIntType(UnknownWidth())
)
assert SIntType(IntWidth(10)).type_eq(
SIntType(IntWidth(10))
)
assert SIntType(UnknownWidth()).type_eq(
SIntType(UnknownWidth())
)
assert VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
def test_type_neq():
assert not UnknownType().type_eq(
ClockType()
)
assert not UnknownType().type_eq(UIntType(
IntWidth(10))
)
assert not UnknownType().type_eq(UIntType(
UnknownWidth())
)
assert not UnknownType().type_eq(SIntType(
IntWidth(10))
)
assert not UnknownType().type_eq(SIntType(
UnknownWidth())
)
assert not UnknownType().type_eq(VectorType(
UIntType(IntWidth(10)), 8)
)
assert not UnknownType().type_eq(VectorType(
UIntType(UnknownWidth()), 8)
)
assert not UnknownType().type_eq(VectorType(
SIntType(IntWidth(10)), 8)
)
assert not UnknownType().type_eq(VectorType(
SIntType(UnknownWidth()), 8)
)
assert not UnknownType().type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not ClockType().type_eq(
UIntType(IntWidth(10))
)
assert not ClockType().type_eq(
UIntType(UnknownWidth())
)
assert not ClockType().type_eq(
SIntType(IntWidth(10))
)
assert not ClockType().type_eq(
SIntType(UnknownWidth())
)
assert not ClockType().type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not ClockType().type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not ClockType().type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not ClockType().type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not ClockType().type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not UIntType(IntWidth(10)).type_eq(
UIntType(UnknownWidth())
)
assert not UIntType(UnknownWidth()).type_eq(
UIntType(IntWidth(10))
)
assert not UIntType(IntWidth(10)).type_eq(
UIntType(IntWidth(8))
)
assert not UIntType(IntWidth(10)).type_eq(
SIntType(IntWidth(10))
)
assert not UIntType(IntWidth(10)).type_eq(
SIntType(UnknownWidth())
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not UIntType(UnknownWidth()).type_eq(
SIntType(IntWidth(10))
)
assert not UIntType(UnknownWidth()).type_eq(
SIntType(UnknownWidth())
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not SIntType(IntWidth(10)).type_eq(
SIntType(UnknownWidth())
)
assert not SIntType(IntWidth(10)).type_eq(
SIntType(IntWidth(8))
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(IntWidth(8)), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(IntWidth(10)), 6)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(IntWidth(8)), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(IntWidth(10)), 6)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UnknownType(), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(ClockType(), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("c", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("b", VectorType(UIntType(IntWidth(10)), 8)),
Field("a", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(8)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 6)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(UnknownWidth()), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(SIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", UnknownType()),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("b", UnknownType()),
Field("a", SIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
UIntType(IntWidth(10))
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8), True),
Field("b", UIntType(IntWidth(10)), True)
])
)

View File

@ -0,0 +1,331 @@
from py_hcl.firrtl_ir.field import Field
from py_hcl.firrtl_ir.tpe import UnknownType, ClockType, \
UIntType, SIntType, VectorType, BundleType
from py_hcl.firrtl_ir.width import IntWidth, UnknownWidth
def test_type_eq():
assert UnknownType().type_eq(
UnknownType()
)
assert ClockType().type_eq(
ClockType()
)
assert UIntType(IntWidth(10)).type_eq(
UIntType(IntWidth(10))
)
assert UIntType(UnknownWidth()).type_eq(
UIntType(UnknownWidth())
)
assert SIntType(IntWidth(10)).type_eq(
SIntType(IntWidth(10))
)
assert SIntType(UnknownWidth()).type_eq(
SIntType(UnknownWidth())
)
assert VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
def test_type_neq():
assert not UnknownType().type_eq(
ClockType()
)
assert not UnknownType().type_eq(UIntType(
IntWidth(10))
)
assert not UnknownType().type_eq(UIntType(
UnknownWidth())
)
assert not UnknownType().type_eq(SIntType(
IntWidth(10))
)
assert not UnknownType().type_eq(SIntType(
UnknownWidth())
)
assert not UnknownType().type_eq(VectorType(
UIntType(IntWidth(10)), 8)
)
assert not UnknownType().type_eq(VectorType(
UIntType(UnknownWidth()), 8)
)
assert not UnknownType().type_eq(VectorType(
SIntType(IntWidth(10)), 8)
)
assert not UnknownType().type_eq(VectorType(
SIntType(UnknownWidth()), 8)
)
assert not UnknownType().type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not ClockType().type_eq(
UIntType(IntWidth(10))
)
assert not ClockType().type_eq(
UIntType(UnknownWidth())
)
assert not ClockType().type_eq(
SIntType(IntWidth(10))
)
assert not ClockType().type_eq(
SIntType(UnknownWidth())
)
assert not ClockType().type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not ClockType().type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not ClockType().type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not ClockType().type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not ClockType().type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not UIntType(IntWidth(10)).type_eq(
UIntType(UnknownWidth())
)
assert not UIntType(UnknownWidth()).type_eq(
UIntType(IntWidth(10))
)
assert not UIntType(IntWidth(10)).type_eq(
UIntType(IntWidth(8))
)
assert not UIntType(IntWidth(10)).type_eq(
SIntType(IntWidth(10))
)
assert not UIntType(IntWidth(10)).type_eq(
SIntType(UnknownWidth())
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not UIntType(IntWidth(10)).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not UIntType(UnknownWidth()).type_eq(
SIntType(IntWidth(10))
)
assert not UIntType(UnknownWidth()).type_eq(
SIntType(UnknownWidth())
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not UIntType(UnknownWidth()).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not SIntType(IntWidth(10)).type_eq(
SIntType(UnknownWidth())
)
assert not SIntType(IntWidth(10)).type_eq(
SIntType(IntWidth(8))
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not SIntType(IntWidth(10)).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(UIntType(IntWidth(10)), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not SIntType(UnknownWidth()).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(IntWidth(8)), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(IntWidth(10)), 6)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UIntType(UnknownWidth()), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(IntWidth(10)), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(IntWidth(8)), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(IntWidth(10)), 6)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(SIntType(UnknownWidth()), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(UnknownType(), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
VectorType(ClockType(), 8)
)
assert not VectorType(UIntType(IntWidth(10)), 8).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("c", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("b", VectorType(UIntType(IntWidth(10)), 8)),
Field("a", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(8)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 6)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(UnknownWidth()), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(SIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", UnknownType()),
Field("b", UIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("b", UnknownType()),
Field("a", SIntType(IntWidth(10)))
])
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
UIntType(IntWidth(10))
)
assert not BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8)),
Field("b", UIntType(IntWidth(10)))
]).type_eq(
BundleType([
Field("a", VectorType(UIntType(IntWidth(10)), 8), True),
Field("b", UIntType(IntWidth(10)), True)
])
)