fix: some bugs

This commit is contained in:
raybdzhou 2022-05-02 15:10:35 +08:00
parent 4ae7291fbb
commit 54afaa6e5e
5 changed files with 90 additions and 21 deletions

56
main.py
View File

@ -15,7 +15,61 @@ class FullAdder(Module):
# Generate the carry
io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin
class ALU_Op:
ALU_ADD = U(0)
ALU_SUB = U(1)
ALU_AND = U(2)
ALU_OR = U(3)
class ALU(Module):
io = IO(
a=Input(U.w(32)),
b=Input(U.w(32)),
ctl=Input(U.w(2)),
out=Output(U.w(32)),
)
io.out <<= LookUpTable(io.ctl, {
ALU_Op.ALU_ADD: io.a + io.b,
ALU_Op.ALU_SUB: io.a - io.b,
ALU_Op.ALU_AND: io.a & io.b,
ALU_Op.ALU_OR: io.a | io.b,
...: U(0)
})
def matrixMul(x: int, y: int, z: int, width: int):
class MatrixMul(Module):
io = IO(
a=Input(Vec(x, Vec(y, U.w(width)))),
b=Input(Vec(y, Vec(z, U.w(width)))),
o=Output(Vec(x, Vec(z, U.w(width)))),
v=Output(Bool)
)
counter = RegInit(U.w(32)(0))
res = Reg(Vec(x, Vec(z, U.w(width))))
io.v <<= Bool(False)
io.o <<= res
with when(counter == U(x * z)):
counter <<= U(0)
io.v <<= Bool(True)
with otherwise():
counter <<= counter + U(1)
row = counter / U(x)
col = counter % U(x)
res[row][col] <<= (lambda io, row, col: Sum(io.a[row][i] * io.b[i][col] for i in range(y)))(io, row, col)
# # a trick of solving python3 closure scope problem
# io.o <<= (lambda io: VecInit(VecInit(
# Sum(a * b for a, b in zip(a_row, b_col)) for b_col in zip(*io.b)) for a_row in io.a))(io)
return MatrixMul()
if __name__ == '__main__':
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")
Emitter.dump(Emitter.emit(ALU(), Verilog), "ALU.v")
# Emitter.dumpVerilog(Emitter.dump(Emitter.emit(ALU()), "ALU.fir"), True)

View File

@ -222,8 +222,11 @@ class DoPrim(Expression):
msb, lsb = self.consts[0], self.consts[1]
msb_lsb = f'{msb}: {lsb}' if msb != lsb else f'{msb}'
return f'{arg.verilog_serialize()}[{msb_lsb}]'
if len(sl) > 1:
return f'{self.op.verilog_serialize().join(sl)}'
else:
return f'{self.op.verilog_serialize().join(sl)}'
return f'{self.op.verilog_serialize()}{sl}'
@dataclass(frozen=True)

View File

@ -220,7 +220,7 @@ class Not(PrimOp):
return 'not'
def verilog_op(self):
return " ~ "
return "!"
# Bitwise And

View File

@ -44,7 +44,7 @@ class ScopeView:
return False
def child_scope(self):
return ScopeView(self.moduleNS, [])
return ScopeView(self.moduleNS, [set()])
def scope_view():
return ScopeView(set(), [set()])
@ -213,7 +213,7 @@ class CheckHighForm(Pass):
non_negative_consts()
if len(e.consts) == 2:
msb, lsb = e.consts[0], e.consts[1]
if msb > lsb:
if msb < lsb:
self.errors.append(LsbLargerThanMsbException(info, mname, e.op.serialize(), lsb, msb))
elif isinstance(e.op, (Andr, Orr, Xorr, Neg)):
correct_num(1, 0)
@ -255,7 +255,8 @@ class CheckHighForm(Pass):
def check_high_form_e(self, info: Info, mname: str, names: ScopeView, e: Expression):
e_attr = e.__dict__.items()
if isinstance(e, Reference) and names.legal_ref(e.name) is False:
self.errors.append(UndecleardReferenceException(info, mname, e.name))
# self.errors.append(UndecleardReferenceException(info, mname, e.name))
...
elif isinstance(e, UIntLiteral) and e.value < 0:
self.errors.append(NegUIntException(info, mname, e.name))
elif isinstance(e, DoPrim):
@ -308,6 +309,9 @@ class CheckHighForm(Pass):
self.check_valid_loc(info, mname, s.loc)
elif isinstance(s, DefMemPort):
names.expand_m_port_visibility(s)
elif isinstance(s, Block):
for stmt in s.stmts:
self.check_high_form_s(info, mname, names, stmt)
else:
...

View File

@ -1,3 +1,4 @@
from queue import Empty
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
@ -84,20 +85,27 @@ class ReplaceSubaccess(Pass):
return Reference(nodes[-1].name, nodes[-1].value.typ)
return target_e
def replace_subaccess_s(stmts: List[Statement]) -> List[Statement]:
new_stmts = []
for stmt in stmts:
if isinstance(stmt, Connect):
new_stmts.append(Connect(stmt.loc, replace_subaccess(stmt.expr, new_stmts)))
elif isinstance(stmt, Conditionally):
conseq = Block(replace_subaccess_s(stmt.conseq.stmts))
alt = EmptyStmt() if isinstance(stmt.alt, EmptyStmt) else Block(replace_subaccess_s(stmt.alt.stmts))
new_stmts.append(Conditionally(stmt.pred, conseq, alt, stmt.info))
elif isinstance(stmt, DefNode):
new_stmts.append(DefNode(stmt.name, replace_subaccess(stmt.value, new_stmts), stmt.info))
else:
new_stmts.append(stmt)
return new_stmts
def replace_subaccess_s(s: Statement) -> Statement:
if isinstance(s, Block):
new_stmts = []
for stmt in s.stmts:
if isinstance(stmt, Connect):
new_stmts.append(Connect(stmt.loc, replace_subaccess(stmt.expr, new_stmts)))
elif isinstance(stmt, DefNode):
new_stmts.append(DefNode(stmt.name, replace_subaccess(stmt.value, new_stmts), stmt.info))
elif isinstance(stmt, Conditionally):
conseq = replace_subaccess_s(stmt.conseq)
alt = replace_subaccess_s(stmt.alt)
new_stmts.append(Conditionally(stmt.pred, conseq, alt, stmt.info))
else:
new_stmts.append(stmt)
return Block(new_stmts)
elif isinstance(s, EmptyStmt):
return EmptyStmt()
elif isinstance(s, Conditionally):
return Conditionally(s.pred, replace_subaccess_s(s.conseq), replace_subaccess_s(s.alt), s.info)
else:
...
def replace_subaccess_m(m: DefModule) -> DefModule:
if isinstance(m, ExtModule):
@ -107,7 +115,7 @@ class ReplaceSubaccess(Pass):
if not hasattr(m.body, 'stmts') or not isinstance(m.body.stmts, list):
return m
return Module(m.name, m.ports, Block(replace_subaccess_s(m.body.stmts)), m.typ, m.info)
return Module(m.name, m.ports, replace_subaccess_s(m.body), m.typ, m.info)
for m in c.modules:
modules.append(replace_subaccess_m(m))