forked from opendacs/PyHCL
refactor: remove infer_types & infer_widths
This commit is contained in:
parent
db3333abda
commit
c94f326cd7
4
main.py
4
main.py
|
@ -17,8 +17,8 @@ class FullAdder(Module):
|
|||
|
||||
if __name__ == '__main__':
|
||||
# emit high firrtl
|
||||
# Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
|
||||
Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
|
||||
# emit lowered firrtl
|
||||
# Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
|
||||
Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
|
||||
# emit verilog
|
||||
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from pyhcl.passes import check_form, check_types, check_widths, check_flows, infer_types, infer_widths
|
||||
from pyhcl.passes import check_form, check_types, check_widths, check_flows, auto_inferring
|
||||
from pyhcl.ir.low_ir import *
|
||||
class CheckAndInfer:
|
||||
@staticmethod
|
||||
def run(c: Circuit):
|
||||
c = check_form.CheckHighForm(c).run()
|
||||
c = infer_types.InferTypes().run(c)
|
||||
c = infer_widths.InferWidths().run(c)
|
||||
c = auto_inferring.AutoInferring().run(c)
|
||||
c = check_types.CheckTypes().run(c)
|
||||
c = check_flows.CheckFlow().run(c)
|
||||
c = check_widths.CheckWidths().run(c)
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
from pyhcl.ir.low_ir import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoInferring:
|
||||
max_width: int = 0
|
||||
|
||||
def run(self, c: Circuit):
|
||||
modules: List[Module] = []
|
||||
|
||||
def auto_inferring_t(t: Type) -> Type:
|
||||
if isinstance(t, UIntType):
|
||||
if t.width.width == 0:
|
||||
return UIntType(IntWidth(self.max_width))
|
||||
else:
|
||||
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
|
||||
return t
|
||||
elif isinstance(t, SIntType):
|
||||
if t.width.width == 0:
|
||||
return SIntType(IntWidth(self.max_width))
|
||||
else:
|
||||
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
|
||||
return t
|
||||
elif isinstance(t, (ClockType, ResetType, AsyncResetType)):
|
||||
return t
|
||||
elif isinstance(t, VectorType):
|
||||
return VectorType(auto_inferring_t(t.typ), t.size)
|
||||
elif isinstance(t, MemoryType):
|
||||
return MemoryType(auto_inferring_t(t.typ), t.size)
|
||||
elif isinstance(t, BundleType):
|
||||
return BundleType([Field(fx.name, fx.flip, auto_inferring_t(fx.typ)) for fx in t.fields])
|
||||
else:
|
||||
return t
|
||||
|
||||
def auto_inferring_e(e: Expression, inferring_map: Dict[str, Type]) -> Expression:
|
||||
if isinstance(e, Mux):
|
||||
return Mux(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.tval, inferring_map),
|
||||
auto_inferring_e(e.fval, inferring_map), auto_inferring_t(e.typ))
|
||||
elif isinstance(e, ValidIf):
|
||||
return ValidIf(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.value, inferring_map), auto_inferring_t(e.typ))
|
||||
elif isinstance(e, DoPrim):
|
||||
return DoPrim(e.op, [auto_inferring_e(arg, inferring_map) for arg in e.args], e.consts, auto_inferring_t(e.typ))
|
||||
elif isinstance(e, UIntLiteral):
|
||||
if e.width.width < get_binary_width(e.value):
|
||||
return UIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
|
||||
else:
|
||||
return e
|
||||
elif isinstance(e, SIntLiteral):
|
||||
if e.width.width < get_binary_width(e.value) + 1:
|
||||
return SIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
|
||||
else:
|
||||
return e
|
||||
elif isinstance(e, Reference):
|
||||
typ = inferring_map[e.name] if inferring_map[e.name] else auto_inferring_t(e.typ)
|
||||
return Reference(e.name, typ)
|
||||
elif isinstance(e, SubField):
|
||||
expr = auto_inferring_e(e.expr, inferring_map)
|
||||
typ = e.typ
|
||||
for fx in expr.typ.fields:
|
||||
if fx.name == e.name:
|
||||
typ = fx.typ
|
||||
return SubField(expr, e.name, typ)
|
||||
elif isinstance(e, SubIndex):
|
||||
expr = auto_inferring_e(e.expr, inferring_map)
|
||||
return SubIndex(e.name, expr, e.value, expr.typ.typ)
|
||||
elif isinstance(e, SubAccess):
|
||||
expr = auto_inferring_e(e.expr, inferring_map)
|
||||
index = auto_inferring_e(e.index, inferring_map)
|
||||
return SubAccess(expr, index, expr.typ.typ)
|
||||
else:
|
||||
return e
|
||||
|
||||
def auto_inferring_s(s: Statement, inferring_map: Dict[str, Type]) -> Statement:
|
||||
if isinstance(s, Block):
|
||||
stmts: List[Statement] = []
|
||||
for sx in s.stmts:
|
||||
stmts.append(auto_inferring_s(sx, inferring_map))
|
||||
return Block(stmts)
|
||||
elif isinstance(s, Conditionally):
|
||||
return Conditionally(auto_inferring_e(s.pred, inferring_map), auto_inferring_s(s.conseq, inferring_map), auto_inferring_s(s.alt, inferring_map), s.info)
|
||||
elif isinstance(s, DefRegister):
|
||||
clock = auto_inferring_e(s.clock, inferring_map)
|
||||
reset = auto_inferring_e(s.reset, inferring_map)
|
||||
init = auto_inferring_e(s.init, inferring_map)
|
||||
typ = auto_inferring_t(s.typ)
|
||||
inferring_map[s.name] = typ
|
||||
return DefRegister(s.name, typ, clock, reset, init, s.info)
|
||||
elif isinstance(s, DefWire):
|
||||
inferring_map[s.name] = auto_inferring_t(s.typ)
|
||||
return s
|
||||
elif isinstance(s, DefMemory):
|
||||
inferring_map[s.name] = auto_inferring_t(s.memType)
|
||||
return s
|
||||
elif isinstance(s, DefNode):
|
||||
value = auto_inferring_e(s.value, inferring_map)
|
||||
inferring_map[s.name] = value.typ
|
||||
return DefNode(s.name, value, s.info)
|
||||
elif isinstance(s, DefMemPort):
|
||||
clk = auto_inferring_e(s.clk, inferring_map)
|
||||
index = auto_inferring_e(s.index, inferring_map)
|
||||
return DefMemPort(s.name, s.mem, index, clk, s.rw, s.info)
|
||||
elif isinstance(s, Connect):
|
||||
return Connect(auto_inferring_e(s.loc, inferring_map), auto_inferring_e(s.expr, inferring_map), s.info, s.blocking, s.bidirection, s.mem)
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def auto_inferring_m(m: DefModule, inferring_map: Dict[str, Type]) -> DefModule:
|
||||
if isinstance(m, Module):
|
||||
ports: List[Port] = []
|
||||
for p in m.ports:
|
||||
inferring_map[p.name] = auto_inferring_t(p.typ)
|
||||
ports.append(Port(p.name, p.direction, inferring_map[p.name], p.info))
|
||||
body = auto_inferring_s(m.body, inferring_map)
|
||||
return Module(m.name, ports, body, m.typ, m.info)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
inferring_map: Dict[str, Type] = {}
|
||||
modules.append(auto_inferring_m(m, inferring_map))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -1,76 +0,0 @@
|
|||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
from pyhcl.passes.wir import *
|
||||
from pyhcl.passes.utils import module_type, field_type, sub_type, mux_type, get_or_else
|
||||
|
||||
class InferTypes(Pass):
|
||||
def run(self, c: Circuit) -> Circuit:
|
||||
mtyps: Dict[str, Type] = {}
|
||||
for m in c.modules:
|
||||
mtyps[m.name] = module_type(m)
|
||||
|
||||
def infer_types_e(typs: Dict[str, Type], e: Expression) -> Expression:
|
||||
if isinstance(e, Reference):
|
||||
return Reference(e.name, get_or_else(e.name in typs.keys(), typs[e.name], UnknownType))
|
||||
elif isinstance(e, SubField):
|
||||
return SubField(e.expr, e.name, field_type(e.expr.typ, e.name))
|
||||
elif isinstance(e, SubIndex):
|
||||
return SubIndex(e.name, e.expr, e.value, sub_type(e.expr.typ))
|
||||
elif isinstance(e, SubAccess):
|
||||
return SubAccess(e.expr, e.index, sub_type(e.expr.typ))
|
||||
elif isinstance(e, DoPrim):
|
||||
return DoPrim(e.op, e.args, e.consts, e.typ)
|
||||
elif isinstance(e, Mux):
|
||||
return Mux(e.cond, e.tval, e.fval, mux_type(e.tval, e.fval))
|
||||
elif isinstance(e, ValidIf):
|
||||
return ValidIf(e.cond, e.value, e.value.typ)
|
||||
else:
|
||||
return e
|
||||
|
||||
def infer_types_s(typs: Dict[str, Type], s: Statement) -> Statement:
|
||||
if isinstance(s, DefRegister):
|
||||
typs[s.name] = s.typ
|
||||
clock = infer_types_e(typs, s.clock) if hasattr(s, 'clock') and isinstance(s.clock, Expression) else None
|
||||
reset = s.reset if hasattr(s, 'reset') and isinstance(s.init, Expression) else Reference('reset', ResetType())
|
||||
init = s.init if hasattr(s, 'init') and isinstance(s.init, Expression) else Reference(s.name, s.typ)
|
||||
return DefRegister(s.name, s.typ, clock, reset, init, s.info)
|
||||
elif isinstance(s, DefWire):
|
||||
typs[s.name] = s.typ
|
||||
return s
|
||||
elif isinstance(s, DefNode):
|
||||
value = infer_types_e(typs, s.value) if hasattr(s, 'value') and isinstance(s.value, Expression) else None
|
||||
typs[s.name] = s.value.typ
|
||||
return DefNode(s.name, value, s.info)
|
||||
elif isinstance(s, DefMemory):
|
||||
typs[s.name] = s.memType
|
||||
return s
|
||||
elif isinstance(s, DefInstance):
|
||||
typs[s.name] = mtyps[s.module]
|
||||
return s
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def infer_types_p(typs: set, p: Port) -> Port:
|
||||
typs[p.name] = p.typ
|
||||
return p
|
||||
|
||||
def infer_types(m: DefModule) -> DefModule:
|
||||
if isinstance(m, ExtModule):
|
||||
return m
|
||||
types: Dict[str, Type] = {}
|
||||
ports = None
|
||||
stmts = None
|
||||
if hasattr(m, 'ports') and isinstance(m.ports, list):
|
||||
ports = list(map(lambda p: infer_types_p(types, p), m.ports))
|
||||
|
||||
if hasattr(m, 'body') and isinstance(m.body, Block):
|
||||
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
|
||||
stmts = list(map(lambda s: infer_types_s(types, s), m.body.stmts))
|
||||
|
||||
return Module(m.name, ports, Block(stmts), m.typ, m.info)
|
||||
|
||||
res = list(map(lambda m: infer_types(m), c.modules))
|
||||
return Circuit(res, c.main, c.info)
|
|
@ -1,100 +0,0 @@
|
|||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
from pyhcl.passes.wir import *
|
||||
|
||||
class InferWidths(Pass):
|
||||
@staticmethod
|
||||
def run(c: Circuit):
|
||||
def infer_widths_t(widths: Dict[str, Width], name: str, t: Type):
|
||||
if isinstance(t, BundleType):
|
||||
bs = list(map(lambda f: infer_widths_t(widths, name, f.typ), t.fields))
|
||||
bs = [b for b in bs if b is not None].pop()
|
||||
t, widths = bs[0], bs[1]
|
||||
|
||||
elif isinstance(t, MemoryType):
|
||||
return infer_widths_t(widths, name, t.typ)
|
||||
|
||||
elif isinstance(t, VectorType):
|
||||
return infer_widths_t(widths, name, t.typ)
|
||||
|
||||
elif isinstance(t, UIntType):
|
||||
if isinstance(t.width, UnknownWidth):
|
||||
return UIntType(widths.values().pop()), widths.clear()
|
||||
else:
|
||||
widths[name] = t.width
|
||||
elif isinstance(t, SIntType):
|
||||
if isinstance(t.width, UnknownWidth):
|
||||
return SIntType(widths.values().pop()), widths.clear()
|
||||
else:
|
||||
widths[name] = t.width
|
||||
|
||||
elif isinstance(t, ClockType):
|
||||
if isinstance(t.width, UnknownWidth):
|
||||
return ClockType(widths.values().pop()), widths.clear()
|
||||
else:
|
||||
widths[name] = t.width
|
||||
|
||||
elif isinstance(t, ResetType):
|
||||
if isinstance(t.width, UnknownWidth):
|
||||
return ResetType(widths.values().pop()), widths.clear()
|
||||
else:
|
||||
widths[name] = t.width
|
||||
|
||||
elif isinstance(t, AsyncResetType):
|
||||
if isinstance(t.width, UnknownWidth):
|
||||
return AsyncResetType(widths.values().pop()), widths.clear()
|
||||
else:
|
||||
widths[name] = t.width
|
||||
|
||||
return t, widths
|
||||
|
||||
def infer_widths_e(widths: Dict[str, Width], e: Expression):
|
||||
if isinstance(e, UIntLiteral):
|
||||
return e, widths
|
||||
elif isinstance(e, SIntLiteral):
|
||||
return e, widths
|
||||
elif isinstance(e, SubField):
|
||||
t, widths = infer_widths_t(widths, e.name, e.typ)
|
||||
return SubField(e.expr, e.name, t), widths
|
||||
elif isinstance(e, SubAccess):
|
||||
t, widths = infer_widths_t(widths, '', e.typ)
|
||||
return SubAccess(e.expr, e.index, t), widths
|
||||
elif isinstance(e, SubIndex):
|
||||
t, widths = infer_widths_t(widths, e.name, e.typ)
|
||||
return SubIndex(e.name, e.expr, e.value, t), widths
|
||||
elif isinstance(e, DoPrim):
|
||||
t, widths = infer_widths_t(widths, '', e.typ)
|
||||
return DoPrim(e.op, e.args, e.consts, t), widths
|
||||
elif isinstance(e, Mux):
|
||||
t, widths = infer_widths_t(widths, '', e.typ)
|
||||
return Mux(e.cond, e.tval, e.fval, t), widths
|
||||
elif isinstance(e, ValidIf):
|
||||
t, widths = infer_widths_t(widths, '', e.typ)
|
||||
return Mux(e.cond, e.value, t), widths
|
||||
else:
|
||||
t, widths = infer_widths_t(widths, e.name, e.typ)
|
||||
return Reference(e.name, t), widths
|
||||
|
||||
|
||||
def infer_widths_s(widths: Dict[str, Width], s: Statement):
|
||||
if isinstance(s, Connect):
|
||||
expr, widths = infer_widths_e(widths, s.expr)
|
||||
loc, widths = infer_widths_e(widths, s.loc)
|
||||
return Connect(loc, expr, s.info)
|
||||
return s
|
||||
|
||||
|
||||
def infer_widths_m(m: DefModule) -> DefModule:
|
||||
if isinstance(m, ExtModule):
|
||||
return m
|
||||
widths: Dict[str, Width] = {}
|
||||
stmts = None
|
||||
if hasattr(m, 'body') and isinstance(m.body, Block):
|
||||
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
|
||||
stmts = list(map(lambda s: infer_widths_s(widths, s), m.body.stmts))
|
||||
|
||||
return Module(m.name, m.ports, Block(stmts), m.typ, m.info)
|
||||
|
||||
return Circuit(list(map(lambda m: infer_widths_m(m), c.modules)), c.main, c.info)
|
Loading…
Reference in New Issue