test_firrtl_ir: add def tests (instance memory node register wire) (#21)

* test_firrtl_ir: add def tests (instance memory node register wire)

* test_firrtl_ir: make style check happy
This commit is contained in:
Gaufoo 2019-11-17 20:41:07 +08:00 committed by GitHub
parent bf6d12045a
commit 4809e03592
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 468 additions and 23 deletions

View File

@ -32,7 +32,7 @@ class DefMemReadPort(Statement):
self.clock_ref.serialize(output)
class DefMemWriteRead(Statement):
class DefMemWritePort(Statement):
def __init__(self, name, mem_ref, index_ref, clock_ref):
self.name = name
self.mem_ref = mem_ref

View File

@ -33,8 +33,8 @@ class DefInitRegister(Statement):
output.write(b", ")
self.clock_ref.serialize(output)
output.write(b" with :\n")
output.write(serialize_str(" " * indent))
output.write(b" (")
output.write(serialize_str(" " * (indent + 1)))
output.write(b"reset => (")
self.reset_ref.serialize(output)
output.write(b", ")
self.init_ref.serialize(output)

View File

@ -1,3 +1,4 @@
from .stmt.definition import DefinitionTypeChecker
from .stmt.block import BlockTypeChecker
from .stmt.conditionally import ConditionallyTypeChecker
from ..stmt import Statement
@ -27,6 +28,7 @@ final_map = {
**ConnectTypeChecker.connect_checker_map,
**ConditionallyTypeChecker.conditionally_checker_map,
**BlockTypeChecker.block_checker_map,
**DefinitionTypeChecker.definition_checker_map,
# simple statement
EmptyStmt: true,

View File

@ -31,8 +31,8 @@ def checker(accessor):
@checker(SubField)
def _(sub_field):
from .. import check
if not check(sub_field.bundle_ref):
from .. import check_all_expr
if not check_all_expr(sub_field.bundle_ref):
return False
if not type_in(sub_field.bundle_ref.tpe, BundleType):
@ -54,8 +54,8 @@ def _(sub_field):
@checker(SubIndex)
def _(sub_index):
from .. import check
if not check(sub_index.vector_ref):
from .. import check_all_expr
if not check_all_expr(sub_index.vector_ref):
return False
if not type_in(sub_index.vector_ref.tpe, VectorType):

View File

@ -169,8 +169,8 @@ def _(binary_bit):
@checker(Not)
def _(n):
from .. import check
if not check(n.arg):
from .. import check_all_expr
if not check_all_expr(n.arg):
return False
if not check_all_same_uint_sint(n.arg.tpe):
@ -188,8 +188,8 @@ def _(n):
@checker(Neg)
def _(neg):
from .. import check
if not check(neg.arg):
from .. import check_all_expr
if not check_all_expr(neg.arg):
return False
if not check_all_same_uint_sint(neg.arg.tpe):
@ -228,8 +228,8 @@ def _(cat):
@checker(Bits)
def _(bits):
from .. import check
if not check(bits.ir_arg):
from .. import check_all_expr
if not check_all_expr(bits.ir_arg):
return False
if not check_all_same_uint_sint(bits.ir_arg.tpe):
@ -253,8 +253,8 @@ def _(bits):
@checker(AsUInt)
def _(as_uint):
from .. import check
if not check(as_uint.arg):
from .. import check_all_expr
if not check_all_expr(as_uint.arg):
return False
if not type_in(as_uint.arg.tpe, UIntType, SIntType, ClockType):
@ -274,8 +274,8 @@ def _(as_uint):
@checker(AsSInt)
def _(as_sint):
from .. import check
if not check(as_sint.arg):
from .. import check_all_expr
if not check_all_expr(as_sint.arg):
return False
if not type_in(as_sint.arg.tpe, UIntType, SIntType, ClockType):
@ -295,8 +295,8 @@ def _(as_sint):
@checker(Shl)
def _(shl):
from .. import check
if not check(shl.ir_arg):
from .. import check_all_expr
if not check_all_expr(shl.ir_arg):
return False
if not check_all_same_uint_sint(shl.ir_arg.tpe, shl.tpe):
@ -311,8 +311,8 @@ def _(shl):
@checker(Shr)
def _(shr):
from .. import check
if not check(shr.ir_arg):
from .. import check_all_expr
if not check_all_expr(shr.ir_arg):
return False
if not check_all_same_uint_sint(shr.ir_arg.tpe, shr.tpe):

View File

@ -0,0 +1,171 @@
from ...shortcuts import uw
from ...type_measurer import equal
from ...stmt.defn.circuit import DefCircuit
from ...stmt.defn.instance import DefInstance
from ...stmt.defn.memory import DefMemory, DefMemReadPort, DefMemWritePort
from ...stmt.defn.module import DefModule, InputPort, OutputPort, DefExtModule
from ...stmt.defn.node import DefNode
from ...type import VectorType, ClockType, UIntType
from ..utils import type_in
from ...stmt.defn.register import DefRegister, DefInitRegister
from ...stmt.defn.wire import DefWire
class DefinitionTypeChecker(object):
definition_checker_map = {}
@staticmethod
def check(op_obj):
try:
return DefinitionTypeChecker \
.definition_checker_map[type(op_obj)](op_obj)
except KeyError:
raise NotImplementedError(type(op_obj))
def checker(definition):
def f(func):
DefinitionTypeChecker.definition_checker_map[definition] = func
return func
return f
###############################################################
# TYPE CHECKERS #
###############################################################
@checker(DefWire)
def _(_):
return True
@checker(DefInstance)
def _(_):
return True
@checker(DefRegister)
def _(reg):
from ...type_checker import check_all_expr
if not check_all_expr(reg.clock_ref):
return False
if not type_in(reg.clock_ref.tpe, ClockType):
return False
return True
@checker(DefInitRegister)
def _(reg):
from ...type_checker import check_all_expr
if not check_all_expr(reg.clock_ref, reg.reset_ref, reg.init_ref):
return False
if not type_in(reg.clock_ref.tpe, ClockType):
return False
if not equal(reg.reset_ref.tpe, uw(1)):
return False
if not equal(reg.init_ref.tpe, reg.tpe):
return False
return True
@checker(DefNode)
def _(node):
from ...type_checker import check_all_expr
if not check_all_expr(node.expr_ref):
return False
return True
@checker(DefMemory)
def _(mem):
if not type_in(mem.tpe, VectorType):
return False
return True
@checker(DefMemReadPort)
def _(mem_read):
from ...type_checker import check_all_expr
if not check_all_expr(mem_read.mem_ref, mem_read.index_ref,
mem_read.clock_ref):
return False
if not type_in(mem_read.mem_ref.tpe, VectorType):
return False
if not type_in(mem_read.clock_ref.tpe, ClockType):
return False
if not type_in(mem_read.index_ref.tpe, UIntType):
return False
return True
@checker(DefMemWritePort)
def _(mem_write):
from ...type_checker import check_all_expr
if not check_all_expr(mem_write.mem_ref, mem_write.index_ref,
mem_write.clock_ref):
return False
if not type_in(mem_write.mem_ref.tpe, VectorType):
return False
if not type_in(mem_write.clock_ref.tpe, ClockType):
return False
if not type_in(mem_write.index_ref.tpe, UIntType):
return False
return True
@checker(DefModule)
def _(mod):
from ...type_checker import check_all_stmt
if not check_all_stmt(mod.body):
return False
for port in mod.ports:
if not type_in(port, InputPort, OutputPort):
return False
return True
@checker(DefExtModule)
def _(mod):
for port in mod.ports:
if not type_in(port, InputPort, OutputPort):
return False
return True
@checker(DefCircuit)
def _(circuit):
from ...type_checker import check_all_stmt
if not check_all_stmt(*circuit.def_modules):
return False
name_found = False
for mod in circuit.def_modules:
if not type_in(mod, DefModule, DefExtModule):
return False
if mod.name == circuit.name:
name_found = True
if not name_found:
return False
return True

View File

@ -13,7 +13,10 @@ def test_conditionally_basis():
cn = Conditionally(n("a", uw(1)), s1, s2)
assert ConditionallyTypeChecker.check(cn)
assert check(cn)
serialize_stmt_equal(cn, "when a :\n skip\nelse :\n a <= b")
serialize_stmt_equal(cn, "when a :\n"
" skip\n"
"else :\n"
" a <= b")
s1 = Block([
Connect(n("a", uw(8)), n("b", uw(8))),
@ -24,7 +27,11 @@ def test_conditionally_basis():
assert ConditionallyTypeChecker.check(cn)
assert check(cn)
serialize_stmt_equal(
cn, 'when UInt<1>("1") :\n a <= b\n c <= d\nelse :\n skip')
cn, 'when UInt<1>("1") :\n'
' a <= b\n'
' c <= d\n'
'else :\n'
' skip')
def test_conditionally_type_wrong():

View File

@ -0,0 +1,10 @@
from py_hcl.firrtl_ir.stmt.defn.instance import DefInstance
from py_hcl.firrtl_ir.type_checker import check, DefinitionTypeChecker
from ...utils import serialize_stmt_equal
def test_instance_basis():
i1 = DefInstance("i1", "ALU")
assert DefinitionTypeChecker.check(i1)
assert check(i1)
serialize_stmt_equal(i1, 'inst i1 of ALU')

View File

@ -0,0 +1,131 @@
from py_hcl.firrtl_ir.shortcuts import vec, uw, bdl, n, u, w, s
from py_hcl.firrtl_ir.stmt.defn.memory import DefMemory, \
DefMemWritePort, DefMemReadPort
from py_hcl.firrtl_ir.type import ClockType
from py_hcl.firrtl_ir.type_checker import check, DefinitionTypeChecker
from ...utils import serialize_stmt_equal
def test_memory_basis():
mem = DefMemory("m", vec(uw(8), 10))
assert DefinitionTypeChecker.check(mem)
assert check(mem)
serialize_stmt_equal(mem, 'cmem m : UInt<8>[10]')
mem = DefMemory("m", vec(bdl(a=(uw(8), False)), 10))
assert DefinitionTypeChecker.check(mem)
assert check(mem)
serialize_stmt_equal(mem, 'cmem m : {a : UInt<8>}[10]')
def test_memory_type_wrong():
mem = DefMemory("m", bdl(a=(vec(uw(8), 10), False)))
assert not DefinitionTypeChecker.check(mem)
assert not check(mem)
mem = DefMemory("m", uw(9))
assert not DefinitionTypeChecker.check(mem)
assert not check(mem)
def test_read_port_basis():
mem_ref = n("m", vec(uw(8), 10))
mr = DefMemReadPort("mr", mem_ref, u(2, w(8)), n("clock", ClockType()))
assert DefinitionTypeChecker.check(mr)
assert check(mr)
serialize_stmt_equal(mr, 'read mport mr = m[UInt<8>("2")], clock')
mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10))
mr = DefMemReadPort("mr", mem_ref, n("a", uw(2)), n("clock", ClockType()))
assert DefinitionTypeChecker.check(mr)
assert check(mr)
serialize_stmt_equal(mr, 'read mport mr = m[a], clock')
def test_read_port_clock_wrong():
mem_ref = n("m", vec(uw(8), 10))
mr = DefMemReadPort("mr", mem_ref, u(2, w(8)), n("clock", uw(1)))
assert not DefinitionTypeChecker.check(mr)
assert not check(mr)
mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10))
mr = DefMemReadPort("mr", mem_ref, n("a", uw(2)), u(0, w(1)))
assert not DefinitionTypeChecker.check(mr)
assert not check(mr)
def test_read_port_index_wrong():
mem_ref = n("m", vec(uw(8), 10))
mr = DefMemReadPort("mr", mem_ref, s(2, w(8)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mr)
assert not check(mr)
mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10))
mr = DefMemReadPort("mr", mem_ref,
n("a", vec(uw(1), 10)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mr)
assert not check(mr)
def test_read_port_mem_wrong():
mem_ref = n("m", bdl(a=(vec(uw(8), 10), False)))
mr = DefMemReadPort("mr", mem_ref, u(2, w(8)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mr)
assert not check(mr)
mem_ref = n("m", uw(9))
mr = DefMemReadPort("mr", mem_ref, n("a", uw(2)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mr)
assert not check(mr)
def test_write_port_basis():
mem_ref = n("m", vec(uw(8), 10))
mw = DefMemWritePort("mw", mem_ref, u(2, w(8)), n("clock", ClockType()))
assert DefinitionTypeChecker.check(mw)
assert check(mw)
serialize_stmt_equal(mw, 'write mport mw = m[UInt<8>("2")], clock')
mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10))
mw = DefMemWritePort("mw", mem_ref, n("a", uw(2)), n("clock", ClockType()))
assert DefinitionTypeChecker.check(mw)
assert check(mw)
serialize_stmt_equal(mw, 'write mport mw = m[a], clock')
def test_write_port_clock_wrong():
mem_ref = n("m", vec(uw(8), 10))
mw = DefMemWritePort("mw", mem_ref, u(2, w(8)), n("clock", uw(1)))
assert not DefinitionTypeChecker.check(mw)
assert not check(mw)
mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10))
mw = DefMemWritePort("mw", mem_ref, n("a", uw(2)), u(0, w(1)))
assert not DefinitionTypeChecker.check(mw)
assert not check(mw)
def test_write_port_index_wrong():
mem_ref = n("m", vec(uw(8), 10))
mw = DefMemWritePort("mw", mem_ref, s(2, w(8)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mw)
assert not check(mw)
mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10))
mw = DefMemWritePort("mw", mem_ref,
n("a", vec(uw(1), 10)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mw)
assert not check(mw)
def test_write_port_mem_wrong():
mem_ref = n("m", bdl(a=(vec(uw(8), 10), False)))
mw = DefMemWritePort("mw", mem_ref, u(2, w(8)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mw)
assert not check(mw)
mem_ref = n("m", uw(9))
mw = DefMemWritePort("mw", mem_ref,
n("a", uw(2)), n("clock", ClockType()))
assert not DefinitionTypeChecker.check(mw)
assert not check(mw)

View File

@ -0,0 +1,27 @@
from py_hcl.firrtl_ir.expr.accessor import SubIndex
from py_hcl.firrtl_ir.shortcuts import n, s, w, vec, uw
from py_hcl.firrtl_ir.stmt.defn.node import DefNode
from py_hcl.firrtl_ir.type_checker import check, DefinitionTypeChecker
from ...utils import serialize_stmt_equal
def test_node_basis():
node = DefNode("n1", s(20, w(6)))
assert DefinitionTypeChecker.check(node)
assert check(node)
serialize_stmt_equal(node, 'node n1 = SInt<6>("14")')
node = DefNode("n2", SubIndex(n("v", vec(uw(8), 10)), 7, uw(8)))
assert DefinitionTypeChecker.check(node)
assert check(node)
serialize_stmt_equal(node, 'node n2 = v[7]')
def test_node_expr_wrong():
node = DefNode("n1", s(20, w(5)))
assert not DefinitionTypeChecker.check(node)
assert not check(node)
node = DefNode("n2", SubIndex(n("v", vec(uw(8), 10)), 10, uw(8)))
assert not DefinitionTypeChecker.check(node)
assert not check(node)

View File

@ -0,0 +1,79 @@
from py_hcl.firrtl_ir.shortcuts import n, s, w, vec, uw, sw, u
from py_hcl.firrtl_ir.stmt.defn.register import DefRegister, DefInitRegister
from py_hcl.firrtl_ir.type import ClockType
from py_hcl.firrtl_ir.type_checker import check, DefinitionTypeChecker
from ...utils import serialize_stmt_equal
def test_register_basis():
r1 = DefRegister("r1", uw(8), n("clock", ClockType()))
assert DefinitionTypeChecker.check(r1)
assert check(r1)
serialize_stmt_equal(r1, 'reg r1 : UInt<8>, clock')
r2 = DefRegister("r2", vec(uw(8), 10), n("clock", ClockType()))
assert DefinitionTypeChecker.check(r2)
assert check(r2)
serialize_stmt_equal(r2, 'reg r2 : UInt<8>[10], clock')
def test_register_clock_wrong():
r1 = DefRegister("r1", uw(8), n("clock", uw(1)))
assert not DefinitionTypeChecker.check(r1)
assert not check(r1)
r2 = DefRegister("r2", vec(uw(8), 10), n("clock", sw(1)))
assert not DefinitionTypeChecker.check(r2)
assert not check(r2)
def test_init_register_basis():
r1 = DefInitRegister("r1", uw(8),
n("clock", ClockType()), n("r", uw(1)), u(5, w(8)))
assert DefinitionTypeChecker.check(r1)
assert check(r1)
serialize_stmt_equal(r1, 'reg r1 : UInt<8>, clock with :\n'
' reset => (r, UInt<8>("5"))')
r2 = DefInitRegister("r2", sw(8),
n("clock", ClockType()), u(0, w(1)), s(5, w(8)))
assert DefinitionTypeChecker.check(r2)
assert check(r2)
serialize_stmt_equal(r2, 'reg r2 : SInt<8>, clock with :\n'
' reset => (UInt<1>("0"), SInt<8>("5"))')
def test_init_register_clock_wrong():
r1 = DefInitRegister("r1", uw(8),
n("clock", uw(1)), n("r", uw(1)), u(5, w(8)))
assert not DefinitionTypeChecker.check(r1)
assert not check(r1)
r2 = DefInitRegister("r2", sw(8),
n("clock", sw(1)), u(0, w(1)), s(5, w(8)))
assert not DefinitionTypeChecker.check(r2)
assert not check(r2)
def test_init_register_reset_wrong():
r1 = DefInitRegister("r1", uw(8),
n("clock", ClockType()), n("r", sw(1)), u(5, w(8)))
assert not DefinitionTypeChecker.check(r1)
assert not check(r1)
r2 = DefInitRegister("r2", sw(8),
n("clock", ClockType()), s(0, w(1)), s(5, w(8)))
assert not DefinitionTypeChecker.check(r2)
assert not check(r2)
def test_init_register_type_not_match():
r1 = DefInitRegister("r1", uw(8),
n("clock", ClockType()), n("r", uw(1)), s(5, w(8)))
assert not DefinitionTypeChecker.check(r1)
assert not check(r1)
r2 = DefInitRegister("r2", uw(8),
n("clock", ClockType()), u(0, w(1)), s(5, w(8)))
assert not DefinitionTypeChecker.check(r2)
assert not check(r2)

View File

@ -0,0 +1,16 @@
from py_hcl.firrtl_ir.shortcuts import uw, bdl, sw
from py_hcl.firrtl_ir.stmt.defn.wire import DefWire
from py_hcl.firrtl_ir.type_checker import check, DefinitionTypeChecker
from ...utils import serialize_stmt_equal
def test_wire_basis():
wire = DefWire("w1", uw(8))
assert DefinitionTypeChecker.check(wire)
assert check(wire)
serialize_stmt_equal(wire, 'wire w1 : UInt<8>')
wire = DefWire("w2", bdl(a=(uw(8), True), b=(sw(8), False)))
assert DefinitionTypeChecker.check(wire)
assert check(wire)
serialize_stmt_equal(wire, 'wire w2 : {flip a : UInt<8>, b : SInt<8>}')

View File

@ -1,6 +1,8 @@
from py_hcl.firrtl_ir.stmt.empty import EmptyStmt
from py_hcl.firrtl_ir.type_checker import check
from ..utils import serialize_stmt_equal
def test_empty_serialize():
serialize_stmt_equal(EmptyStmt(), "skip")
assert check(EmptyStmt())