refactor: remove infer_types & infer_widths

This commit is contained in:
raybdzhou 2022-06-10 21:03:12 +08:00
parent db3333abda
commit c94f326cd7
5 changed files with 128 additions and 181 deletions

View File

@ -17,8 +17,8 @@ class FullAdder(Module):
if __name__ == '__main__':
# emit high firrtl
# Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
# emit lowered firrtl
# Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
# emit verilog
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")

View File

@ -1,11 +1,10 @@
from pyhcl.passes import check_form, check_types, check_widths, check_flows, infer_types, infer_widths
from pyhcl.passes import check_form, check_types, check_widths, check_flows, auto_inferring
from pyhcl.ir.low_ir import *
class CheckAndInfer:
@staticmethod
def run(c: Circuit):
c = check_form.CheckHighForm(c).run()
c = infer_types.InferTypes().run(c)
c = infer_widths.InferWidths().run(c)
c = auto_inferring.AutoInferring().run(c)
c = check_types.CheckTypes().run(c)
c = check_flows.CheckFlow().run(c)
c = check_widths.CheckWidths().run(c)

View File

@ -0,0 +1,124 @@
from dataclasses import dataclass
from typing import Dict, List
from pyhcl.ir.low_ir import *
@dataclass
class AutoInferring:
max_width: int = 0
def run(self, c: Circuit):
modules: List[Module] = []
def auto_inferring_t(t: Type) -> Type:
if isinstance(t, UIntType):
if t.width.width == 0:
return UIntType(IntWidth(self.max_width))
else:
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
return t
elif isinstance(t, SIntType):
if t.width.width == 0:
return SIntType(IntWidth(self.max_width))
else:
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
return t
elif isinstance(t, (ClockType, ResetType, AsyncResetType)):
return t
elif isinstance(t, VectorType):
return VectorType(auto_inferring_t(t.typ), t.size)
elif isinstance(t, MemoryType):
return MemoryType(auto_inferring_t(t.typ), t.size)
elif isinstance(t, BundleType):
return BundleType([Field(fx.name, fx.flip, auto_inferring_t(fx.typ)) for fx in t.fields])
else:
return t
def auto_inferring_e(e: Expression, inferring_map: Dict[str, Type]) -> Expression:
if isinstance(e, Mux):
return Mux(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.tval, inferring_map),
auto_inferring_e(e.fval, inferring_map), auto_inferring_t(e.typ))
elif isinstance(e, ValidIf):
return ValidIf(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.value, inferring_map), auto_inferring_t(e.typ))
elif isinstance(e, DoPrim):
return DoPrim(e.op, [auto_inferring_e(arg, inferring_map) for arg in e.args], e.consts, auto_inferring_t(e.typ))
elif isinstance(e, UIntLiteral):
if e.width.width < get_binary_width(e.value):
return UIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
else:
return e
elif isinstance(e, SIntLiteral):
if e.width.width < get_binary_width(e.value) + 1:
return SIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
else:
return e
elif isinstance(e, Reference):
typ = inferring_map[e.name] if inferring_map[e.name] else auto_inferring_t(e.typ)
return Reference(e.name, typ)
elif isinstance(e, SubField):
expr = auto_inferring_e(e.expr, inferring_map)
typ = e.typ
for fx in expr.typ.fields:
if fx.name == e.name:
typ = fx.typ
return SubField(expr, e.name, typ)
elif isinstance(e, SubIndex):
expr = auto_inferring_e(e.expr, inferring_map)
return SubIndex(e.name, expr, e.value, expr.typ.typ)
elif isinstance(e, SubAccess):
expr = auto_inferring_e(e.expr, inferring_map)
index = auto_inferring_e(e.index, inferring_map)
return SubAccess(expr, index, expr.typ.typ)
else:
return e
def auto_inferring_s(s: Statement, inferring_map: Dict[str, Type]) -> Statement:
if isinstance(s, Block):
stmts: List[Statement] = []
for sx in s.stmts:
stmts.append(auto_inferring_s(sx, inferring_map))
return Block(stmts)
elif isinstance(s, Conditionally):
return Conditionally(auto_inferring_e(s.pred, inferring_map), auto_inferring_s(s.conseq, inferring_map), auto_inferring_s(s.alt, inferring_map), s.info)
elif isinstance(s, DefRegister):
clock = auto_inferring_e(s.clock, inferring_map)
reset = auto_inferring_e(s.reset, inferring_map)
init = auto_inferring_e(s.init, inferring_map)
typ = auto_inferring_t(s.typ)
inferring_map[s.name] = typ
return DefRegister(s.name, typ, clock, reset, init, s.info)
elif isinstance(s, DefWire):
inferring_map[s.name] = auto_inferring_t(s.typ)
return s
elif isinstance(s, DefMemory):
inferring_map[s.name] = auto_inferring_t(s.memType)
return s
elif isinstance(s, DefNode):
value = auto_inferring_e(s.value, inferring_map)
inferring_map[s.name] = value.typ
return DefNode(s.name, value, s.info)
elif isinstance(s, DefMemPort):
clk = auto_inferring_e(s.clk, inferring_map)
index = auto_inferring_e(s.index, inferring_map)
return DefMemPort(s.name, s.mem, index, clk, s.rw, s.info)
elif isinstance(s, Connect):
return Connect(auto_inferring_e(s.loc, inferring_map), auto_inferring_e(s.expr, inferring_map), s.info, s.blocking, s.bidirection, s.mem)
else:
return s
def auto_inferring_m(m: DefModule, inferring_map: Dict[str, Type]) -> DefModule:
if isinstance(m, Module):
ports: List[Port] = []
for p in m.ports:
inferring_map[p.name] = auto_inferring_t(p.typ)
ports.append(Port(p.name, p.direction, inferring_map[p.name], p.info))
body = auto_inferring_s(m.body, inferring_map)
return Module(m.name, ports, body, m.typ, m.info)
else:
return m
for m in c.modules:
inferring_map: Dict[str, Type] = {}
modules.append(auto_inferring_m(m, inferring_map))
return Circuit(modules, c.main, c.info)

View File

@ -1,76 +0,0 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
from pyhcl.passes.wir import *
from pyhcl.passes.utils import module_type, field_type, sub_type, mux_type, get_or_else
class InferTypes(Pass):
def run(self, c: Circuit) -> Circuit:
mtyps: Dict[str, Type] = {}
for m in c.modules:
mtyps[m.name] = module_type(m)
def infer_types_e(typs: Dict[str, Type], e: Expression) -> Expression:
if isinstance(e, Reference):
return Reference(e.name, get_or_else(e.name in typs.keys(), typs[e.name], UnknownType))
elif isinstance(e, SubField):
return SubField(e.expr, e.name, field_type(e.expr.typ, e.name))
elif isinstance(e, SubIndex):
return SubIndex(e.name, e.expr, e.value, sub_type(e.expr.typ))
elif isinstance(e, SubAccess):
return SubAccess(e.expr, e.index, sub_type(e.expr.typ))
elif isinstance(e, DoPrim):
return DoPrim(e.op, e.args, e.consts, e.typ)
elif isinstance(e, Mux):
return Mux(e.cond, e.tval, e.fval, mux_type(e.tval, e.fval))
elif isinstance(e, ValidIf):
return ValidIf(e.cond, e.value, e.value.typ)
else:
return e
def infer_types_s(typs: Dict[str, Type], s: Statement) -> Statement:
if isinstance(s, DefRegister):
typs[s.name] = s.typ
clock = infer_types_e(typs, s.clock) if hasattr(s, 'clock') and isinstance(s.clock, Expression) else None
reset = s.reset if hasattr(s, 'reset') and isinstance(s.init, Expression) else Reference('reset', ResetType())
init = s.init if hasattr(s, 'init') and isinstance(s.init, Expression) else Reference(s.name, s.typ)
return DefRegister(s.name, s.typ, clock, reset, init, s.info)
elif isinstance(s, DefWire):
typs[s.name] = s.typ
return s
elif isinstance(s, DefNode):
value = infer_types_e(typs, s.value) if hasattr(s, 'value') and isinstance(s.value, Expression) else None
typs[s.name] = s.value.typ
return DefNode(s.name, value, s.info)
elif isinstance(s, DefMemory):
typs[s.name] = s.memType
return s
elif isinstance(s, DefInstance):
typs[s.name] = mtyps[s.module]
return s
else:
return s
def infer_types_p(typs: set, p: Port) -> Port:
typs[p.name] = p.typ
return p
def infer_types(m: DefModule) -> DefModule:
if isinstance(m, ExtModule):
return m
types: Dict[str, Type] = {}
ports = None
stmts = None
if hasattr(m, 'ports') and isinstance(m.ports, list):
ports = list(map(lambda p: infer_types_p(types, p), m.ports))
if hasattr(m, 'body') and isinstance(m.body, Block):
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
stmts = list(map(lambda s: infer_types_s(types, s), m.body.stmts))
return Module(m.name, ports, Block(stmts), m.typ, m.info)
res = list(map(lambda m: infer_types(m), c.modules))
return Circuit(res, c.main, c.info)

View File

@ -1,100 +0,0 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
from pyhcl.passes.wir import *
class InferWidths(Pass):
@staticmethod
def run(c: Circuit):
def infer_widths_t(widths: Dict[str, Width], name: str, t: Type):
if isinstance(t, BundleType):
bs = list(map(lambda f: infer_widths_t(widths, name, f.typ), t.fields))
bs = [b for b in bs if b is not None].pop()
t, widths = bs[0], bs[1]
elif isinstance(t, MemoryType):
return infer_widths_t(widths, name, t.typ)
elif isinstance(t, VectorType):
return infer_widths_t(widths, name, t.typ)
elif isinstance(t, UIntType):
if isinstance(t.width, UnknownWidth):
return UIntType(widths.values().pop()), widths.clear()
else:
widths[name] = t.width
elif isinstance(t, SIntType):
if isinstance(t.width, UnknownWidth):
return SIntType(widths.values().pop()), widths.clear()
else:
widths[name] = t.width
elif isinstance(t, ClockType):
if isinstance(t.width, UnknownWidth):
return ClockType(widths.values().pop()), widths.clear()
else:
widths[name] = t.width
elif isinstance(t, ResetType):
if isinstance(t.width, UnknownWidth):
return ResetType(widths.values().pop()), widths.clear()
else:
widths[name] = t.width
elif isinstance(t, AsyncResetType):
if isinstance(t.width, UnknownWidth):
return AsyncResetType(widths.values().pop()), widths.clear()
else:
widths[name] = t.width
return t, widths
def infer_widths_e(widths: Dict[str, Width], e: Expression):
if isinstance(e, UIntLiteral):
return e, widths
elif isinstance(e, SIntLiteral):
return e, widths
elif isinstance(e, SubField):
t, widths = infer_widths_t(widths, e.name, e.typ)
return SubField(e.expr, e.name, t), widths
elif isinstance(e, SubAccess):
t, widths = infer_widths_t(widths, '', e.typ)
return SubAccess(e.expr, e.index, t), widths
elif isinstance(e, SubIndex):
t, widths = infer_widths_t(widths, e.name, e.typ)
return SubIndex(e.name, e.expr, e.value, t), widths
elif isinstance(e, DoPrim):
t, widths = infer_widths_t(widths, '', e.typ)
return DoPrim(e.op, e.args, e.consts, t), widths
elif isinstance(e, Mux):
t, widths = infer_widths_t(widths, '', e.typ)
return Mux(e.cond, e.tval, e.fval, t), widths
elif isinstance(e, ValidIf):
t, widths = infer_widths_t(widths, '', e.typ)
return Mux(e.cond, e.value, t), widths
else:
t, widths = infer_widths_t(widths, e.name, e.typ)
return Reference(e.name, t), widths
def infer_widths_s(widths: Dict[str, Width], s: Statement):
if isinstance(s, Connect):
expr, widths = infer_widths_e(widths, s.expr)
loc, widths = infer_widths_e(widths, s.loc)
return Connect(loc, expr, s.info)
return s
def infer_widths_m(m: DefModule) -> DefModule:
if isinstance(m, ExtModule):
return m
widths: Dict[str, Width] = {}
stmts = None
if hasattr(m, 'body') and isinstance(m.body, Block):
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
stmts = list(map(lambda s: infer_widths_s(widths, s), m.body.stmts))
return Module(m.name, m.ports, Block(stmts), m.typ, m.info)
return Circuit(list(map(lambda m: infer_widths_m(m), c.modules)), c.main, c.info)