test_firrtl_ir: refactor (#24)

This commit is contained in:
Gaufoo 2019-11-18 01:23:23 +08:00 committed by GitHub
parent 0ec61e16c9
commit c1d76bd49a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 80 deletions

View File

@ -25,7 +25,7 @@ class OpCase(object):
def name_gen():
return "n" + str(uuid.uuid4()).replace("-", "_")
return "n" + str(uuid.uuid4()).replace("-", "")
def u_gen():
@ -70,40 +70,35 @@ def unknown_gen():
def vec_gen():
rand_num = random.randint(1, 100)
seed = random.choices(population=[1, 2, 3, 4, 5],
weights=[0.3, 0.3, 0.3, 0.05, 0.05],
seed = random.choices(population=[1, 2, 3, 4],
weights=[0.4, 0.4, 0.1, 0.1],
k=1)[0]
if seed == 1:
return vec(uw_gen(), rand_num)
elif seed == 2:
return vec(sw_gen(), rand_num)
elif seed == 3:
return vec(unknown_gen(), rand_num)
elif seed == 4:
return vec(vec_gen(), rand_num)
elif seed == 5:
elif seed == 4:
return vec(bdl_gen(), rand_num)
def bdl_gen():
# rand_num = random.randint(1, 10)
rand_num = 1
rand_num = random.randint(1, 5)
fields = []
for i in range(rand_num):
rand_name = name_gen()
rand_bool = random.choice([True, False])
seed = random.choices(population=[1, 2, 3, 4, 5],
weights=[0.3, 0.3, 0.3, 0.05, 0.05],
seed = random.choices(population=[1, 2, 3, 4],
weights=[0.4, 0.4, 0.1, 0.1],
k=1)[0]
if seed == 1:
fields.append(Field(rand_name, uw_gen(), rand_bool))
elif seed == 2:
fields.append(Field(rand_name, sw_gen(), rand_bool))
elif seed == 3:
fields.append(Field(rand_name, unknown_gen(), rand_bool))
elif seed == 4:
fields.append(Field(rand_name, vec_gen(), rand_bool))
elif seed == 5:
elif seed == 4:
fields.append(Field(rand_name, bdl_gen(), rand_bool))
return BundleType(fields)
@ -125,7 +120,7 @@ def obj_gen(case):
def basis_tester(cases):
for case in cases:
for i in range(10):
for i in range(5):
obj = obj_gen(case)
assert OpTypeChecker.check(obj)
assert check(obj)
@ -133,7 +128,17 @@ def basis_tester(cases):
def encounter_error_tester(cases):
for case in cases:
for i in range(10):
for i in range(5):
obj = obj_gen(case)
assert not OpTypeChecker.check(obj)
assert not check(obj)
if __name__ == '__main__':
from io import BytesIO
output = BytesIO()
vec_gen().serialize(output)
output.flush()
print(str(output.getvalue(), "utf-8"))

View File

@ -1,6 +1,7 @@
from py_hcl.firrtl_ir.expr.prim_ops import Add
from py_hcl.firrtl_ir.shortcuts import uw, sw
from py_hcl.firrtl_ir.type import UIntType, SIntType, VectorType, BundleType
from py_hcl.firrtl_ir.type import UIntType, SIntType, VectorType, \
BundleType, UnknownType
from .helper import OpCase, basis_tester, encounter_error_tester
@ -8,76 +9,52 @@ def max_width(x, y):
return max(x.tpe.width.width, y.tpe.width.width)
add_basis_cases = [
OpCase(
Add
).arg_types(
UIntType, UIntType
).res_type(
lambda x, y: uw(max_width(x, y) + 1)
),
def args(*arg_types):
class C:
@staticmethod
def tpe(res_type):
return OpCase(Add).arg_types(*arg_types).res_type(res_type)
OpCase(
Add
).arg_types(
SIntType, SIntType
).res_type(
lambda x, y: sw(max_width(x, y) + 1)
),
return C
add_basis_cases = [
args(UIntType, UIntType).tpe(lambda x, y: uw(max_width(x, y) + 1)),
args(SIntType, SIntType).tpe(lambda x, y: sw(max_width(x, y) + 1)),
]
add_type_wrong_cases = [
OpCase(
Add
).arg_types(
UIntType, VectorType
).res_type(
lambda x, y: uw(5)
),
OpCase(
Add
).arg_types(
UIntType, BundleType
).res_type(
lambda x, y: uw(5)
),
OpCase(
Add
).arg_types(
SIntType, VectorType
).res_type(
lambda x, y: sw(5)
),
OpCase(
Add
).arg_types(
SIntType, BundleType
).res_type(
lambda x, y: sw(5)
),
args(UnknownType, UnknownType).tpe(lambda x, y: uw(32)),
args(UnknownType, UIntType).tpe(lambda x, y: uw(32)),
args(UnknownType, SIntType).tpe(lambda x, y: sw(32)),
args(UnknownType, VectorType).tpe(lambda x, y: sw(32)),
args(UIntType, VectorType).tpe(lambda x, y: uw(32)),
args(UIntType, BundleType).tpe(lambda x, y: uw(32)),
args(UIntType, UnknownType).tpe(lambda x, y: uw(32)),
args(SIntType, VectorType).tpe(lambda x, y: sw(32)),
args(SIntType, BundleType).tpe(lambda x, y: sw(32)),
args(SIntType, UnknownType).tpe(lambda x, y: sw(32)),
args(VectorType, UIntType).tpe(lambda x, y: uw(32)),
args(VectorType, SIntType).tpe(lambda x, y: sw(32)),
args(VectorType, VectorType).tpe(lambda x, y: uw(32)),
args(VectorType, BundleType).tpe(lambda x, y: uw(32)),
args(BundleType, UIntType).tpe(lambda x, y: uw(32)),
args(BundleType, SIntType).tpe(lambda x, y: uw(32)),
args(BundleType, VectorType).tpe(lambda x, y: uw(32)),
args(BundleType, UnknownType).tpe(lambda x, y: uw(32)),
args(BundleType, BundleType).tpe(lambda x, y: uw(32)),
]
add_width_wrong_cases = [
OpCase(
Add
).arg_types(
UIntType, UIntType
).res_type(
lambda x, y: uw(max_width(x, y))
),
OpCase(
Add
).arg_types(
SIntType, SIntType
).res_type(
lambda x, y: sw(max_width(x, y))
),
args(UIntType, UIntType).tpe(lambda x, y: uw(max_width(x, y))),
args(SIntType, SIntType).tpe(lambda x, y: sw(max_width(x, y))),
args(UIntType, UIntType).tpe(lambda x, y: uw(max_width(x, y) - 1)),
args(SIntType, SIntType).tpe(lambda x, y: sw(max_width(x, y) - 1)),
args(SIntType, SIntType).tpe(lambda x, y: sw(1)),
]
basis_tester(add_basis_cases)
encounter_error_tester(add_type_wrong_cases)
encounter_error_tester(add_width_wrong_cases)
def test_add():
basis_tester(add_basis_cases)
encounter_error_tester(add_type_wrong_cases)
encounter_error_tester(add_width_wrong_cases)