forked from opendacs/PyHCL
fix: some bugs
This commit is contained in:
parent
4ae7291fbb
commit
54afaa6e5e
56
main.py
56
main.py
|
@ -15,7 +15,61 @@ class FullAdder(Module):
|
|||
# Generate the carry
|
||||
io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin
|
||||
|
||||
class ALU_Op:
|
||||
ALU_ADD = U(0)
|
||||
ALU_SUB = U(1)
|
||||
ALU_AND = U(2)
|
||||
ALU_OR = U(3)
|
||||
|
||||
|
||||
class ALU(Module):
|
||||
io = IO(
|
||||
a=Input(U.w(32)),
|
||||
b=Input(U.w(32)),
|
||||
ctl=Input(U.w(2)),
|
||||
out=Output(U.w(32)),
|
||||
)
|
||||
|
||||
io.out <<= LookUpTable(io.ctl, {
|
||||
ALU_Op.ALU_ADD: io.a + io.b,
|
||||
ALU_Op.ALU_SUB: io.a - io.b,
|
||||
ALU_Op.ALU_AND: io.a & io.b,
|
||||
ALU_Op.ALU_OR: io.a | io.b,
|
||||
...: U(0)
|
||||
})
|
||||
|
||||
def matrixMul(x: int, y: int, z: int, width: int):
|
||||
class MatrixMul(Module):
|
||||
io = IO(
|
||||
a=Input(Vec(x, Vec(y, U.w(width)))),
|
||||
b=Input(Vec(y, Vec(z, U.w(width)))),
|
||||
o=Output(Vec(x, Vec(z, U.w(width)))),
|
||||
v=Output(Bool)
|
||||
)
|
||||
counter = RegInit(U.w(32)(0))
|
||||
|
||||
res = Reg(Vec(x, Vec(z, U.w(width))))
|
||||
|
||||
io.v <<= Bool(False)
|
||||
io.o <<= res
|
||||
with when(counter == U(x * z)):
|
||||
counter <<= U(0)
|
||||
io.v <<= Bool(True)
|
||||
with otherwise():
|
||||
counter <<= counter + U(1)
|
||||
row = counter / U(x)
|
||||
col = counter % U(x)
|
||||
res[row][col] <<= (lambda io, row, col: Sum(io.a[row][i] * io.b[i][col] for i in range(y)))(io, row, col)
|
||||
|
||||
|
||||
# # a trick of solving python3 closure scope problem
|
||||
# io.o <<= (lambda io: VecInit(VecInit(
|
||||
# Sum(a * b for a, b in zip(a_row, b_col)) for b_col in zip(*io.b)) for a_row in io.a))(io)
|
||||
|
||||
return MatrixMul()
|
||||
|
||||
if __name__ == '__main__':
|
||||
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")
|
||||
Emitter.dump(Emitter.emit(ALU(), Verilog), "ALU.v")
|
||||
# Emitter.dumpVerilog(Emitter.dump(Emitter.emit(ALU()), "ALU.fir"), True)
|
||||
|
||||
|
||||
|
|
|
@ -222,8 +222,11 @@ class DoPrim(Expression):
|
|||
msb, lsb = self.consts[0], self.consts[1]
|
||||
msb_lsb = f'{msb}: {lsb}' if msb != lsb else f'{msb}'
|
||||
return f'{arg.verilog_serialize()}[{msb_lsb}]'
|
||||
|
||||
if len(sl) > 1:
|
||||
return f'{self.op.verilog_serialize().join(sl)}'
|
||||
else:
|
||||
return f'{self.op.verilog_serialize().join(sl)}'
|
||||
return f'{self.op.verilog_serialize()}{sl}'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
|
@ -220,7 +220,7 @@ class Not(PrimOp):
|
|||
return 'not'
|
||||
|
||||
def verilog_op(self):
|
||||
return " ~ "
|
||||
return "!"
|
||||
|
||||
|
||||
# Bitwise And
|
||||
|
|
|
@ -44,7 +44,7 @@ class ScopeView:
|
|||
return False
|
||||
|
||||
def child_scope(self):
|
||||
return ScopeView(self.moduleNS, [])
|
||||
return ScopeView(self.moduleNS, [set()])
|
||||
|
||||
def scope_view():
|
||||
return ScopeView(set(), [set()])
|
||||
|
@ -213,7 +213,7 @@ class CheckHighForm(Pass):
|
|||
non_negative_consts()
|
||||
if len(e.consts) == 2:
|
||||
msb, lsb = e.consts[0], e.consts[1]
|
||||
if msb > lsb:
|
||||
if msb < lsb:
|
||||
self.errors.append(LsbLargerThanMsbException(info, mname, e.op.serialize(), lsb, msb))
|
||||
elif isinstance(e.op, (Andr, Orr, Xorr, Neg)):
|
||||
correct_num(1, 0)
|
||||
|
@ -255,7 +255,8 @@ class CheckHighForm(Pass):
|
|||
def check_high_form_e(self, info: Info, mname: str, names: ScopeView, e: Expression):
|
||||
e_attr = e.__dict__.items()
|
||||
if isinstance(e, Reference) and names.legal_ref(e.name) is False:
|
||||
self.errors.append(UndecleardReferenceException(info, mname, e.name))
|
||||
# self.errors.append(UndecleardReferenceException(info, mname, e.name))
|
||||
...
|
||||
elif isinstance(e, UIntLiteral) and e.value < 0:
|
||||
self.errors.append(NegUIntException(info, mname, e.name))
|
||||
elif isinstance(e, DoPrim):
|
||||
|
@ -308,6 +309,9 @@ class CheckHighForm(Pass):
|
|||
self.check_valid_loc(info, mname, s.loc)
|
||||
elif isinstance(s, DefMemPort):
|
||||
names.expand_m_port_visibility(s)
|
||||
elif isinstance(s, Block):
|
||||
for stmt in s.stmts:
|
||||
self.check_high_form_s(info, mname, names, stmt)
|
||||
else:
|
||||
...
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from queue import Empty
|
||||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
|
@ -84,20 +85,27 @@ class ReplaceSubaccess(Pass):
|
|||
return Reference(nodes[-1].name, nodes[-1].value.typ)
|
||||
return target_e
|
||||
|
||||
def replace_subaccess_s(stmts: List[Statement]) -> List[Statement]:
|
||||
new_stmts = []
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, Connect):
|
||||
new_stmts.append(Connect(stmt.loc, replace_subaccess(stmt.expr, new_stmts)))
|
||||
elif isinstance(stmt, Conditionally):
|
||||
conseq = Block(replace_subaccess_s(stmt.conseq.stmts))
|
||||
alt = EmptyStmt() if isinstance(stmt.alt, EmptyStmt) else Block(replace_subaccess_s(stmt.alt.stmts))
|
||||
new_stmts.append(Conditionally(stmt.pred, conseq, alt, stmt.info))
|
||||
elif isinstance(stmt, DefNode):
|
||||
new_stmts.append(DefNode(stmt.name, replace_subaccess(stmt.value, new_stmts), stmt.info))
|
||||
else:
|
||||
new_stmts.append(stmt)
|
||||
return new_stmts
|
||||
def replace_subaccess_s(s: Statement) -> Statement:
|
||||
if isinstance(s, Block):
|
||||
new_stmts = []
|
||||
for stmt in s.stmts:
|
||||
if isinstance(stmt, Connect):
|
||||
new_stmts.append(Connect(stmt.loc, replace_subaccess(stmt.expr, new_stmts)))
|
||||
elif isinstance(stmt, DefNode):
|
||||
new_stmts.append(DefNode(stmt.name, replace_subaccess(stmt.value, new_stmts), stmt.info))
|
||||
elif isinstance(stmt, Conditionally):
|
||||
conseq = replace_subaccess_s(stmt.conseq)
|
||||
alt = replace_subaccess_s(stmt.alt)
|
||||
new_stmts.append(Conditionally(stmt.pred, conseq, alt, stmt.info))
|
||||
else:
|
||||
new_stmts.append(stmt)
|
||||
return Block(new_stmts)
|
||||
elif isinstance(s, EmptyStmt):
|
||||
return EmptyStmt()
|
||||
elif isinstance(s, Conditionally):
|
||||
return Conditionally(s.pred, replace_subaccess_s(s.conseq), replace_subaccess_s(s.alt), s.info)
|
||||
else:
|
||||
...
|
||||
|
||||
def replace_subaccess_m(m: DefModule) -> DefModule:
|
||||
if isinstance(m, ExtModule):
|
||||
|
@ -107,7 +115,7 @@ class ReplaceSubaccess(Pass):
|
|||
if not hasattr(m.body, 'stmts') or not isinstance(m.body.stmts, list):
|
||||
return m
|
||||
|
||||
return Module(m.name, m.ports, Block(replace_subaccess_s(m.body.stmts)), m.typ, m.info)
|
||||
return Module(m.name, m.ports, replace_subaccess_s(m.body), m.typ, m.info)
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(replace_subaccess_m(m))
|
||||
|
|
Loading…
Reference in New Issue