From c94f326cd708b75946b2d6175e37ba3a3e5416dd Mon Sep 17 00:00:00 2001 From: raybdzhou Date: Fri, 10 Jun 2022 21:03:12 +0800 Subject: [PATCH] refactor: remove infer_types & infer_widths --- main.py | 4 +- pyhcl/dsl/check_and_infer.py | 5 +- pyhcl/passes/auto_inferring.py | 124 +++++++++++++++++++++++++++++++++ pyhcl/passes/infer_types.py | 76 -------------------- pyhcl/passes/infer_widths.py | 100 -------------------------- 5 files changed, 128 insertions(+), 181 deletions(-) create mode 100644 pyhcl/passes/auto_inferring.py delete mode 100644 pyhcl/passes/infer_types.py delete mode 100644 pyhcl/passes/infer_widths.py diff --git a/main.py b/main.py index 95bf313..a123bdd 100644 --- a/main.py +++ b/main.py @@ -17,8 +17,8 @@ class FullAdder(Module): if __name__ == '__main__': # emit high firrtl - # Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir") + Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir") # emit lowered firrtl - # Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir") + Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir") # emit verilog Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v") diff --git a/pyhcl/dsl/check_and_infer.py b/pyhcl/dsl/check_and_infer.py index 025be20..19c8079 100644 --- a/pyhcl/dsl/check_and_infer.py +++ b/pyhcl/dsl/check_and_infer.py @@ -1,11 +1,10 @@ -from pyhcl.passes import check_form, check_types, check_widths, check_flows, infer_types, infer_widths +from pyhcl.passes import check_form, check_types, check_widths, check_flows, auto_inferring from pyhcl.ir.low_ir import * class CheckAndInfer: @staticmethod def run(c: Circuit): c = check_form.CheckHighForm(c).run() - c = infer_types.InferTypes().run(c) - c = infer_widths.InferWidths().run(c) + c = auto_inferring.AutoInferring().run(c) c = check_types.CheckTypes().run(c) c = check_flows.CheckFlow().run(c) c = check_widths.CheckWidths().run(c) diff --git a/pyhcl/passes/auto_inferring.py b/pyhcl/passes/auto_inferring.py new file mode 100644 index 0000000..595796e --- /dev/null +++ b/pyhcl/passes/auto_inferring.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from typing import Dict, List +from pyhcl.ir.low_ir import * + + +@dataclass +class AutoInferring: + max_width: int = 0 + + def run(self, c: Circuit): + modules: List[Module] = [] + + def auto_inferring_t(t: Type) -> Type: + if isinstance(t, UIntType): + if t.width.width == 0: + return UIntType(IntWidth(self.max_width)) + else: + self.max_width = self.max_width if self.max_width > t.width.width else t.width.width + return t + elif isinstance(t, SIntType): + if t.width.width == 0: + return SIntType(IntWidth(self.max_width)) + else: + self.max_width = self.max_width if self.max_width > t.width.width else t.width.width + return t + elif isinstance(t, (ClockType, ResetType, AsyncResetType)): + return t + elif isinstance(t, VectorType): + return VectorType(auto_inferring_t(t.typ), t.size) + elif isinstance(t, MemoryType): + return MemoryType(auto_inferring_t(t.typ), t.size) + elif isinstance(t, BundleType): + return BundleType([Field(fx.name, fx.flip, auto_inferring_t(fx.typ)) for fx in t.fields]) + else: + return t + + def auto_inferring_e(e: Expression, inferring_map: Dict[str, Type]) -> Expression: + if isinstance(e, Mux): + return Mux(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.tval, inferring_map), + auto_inferring_e(e.fval, inferring_map), auto_inferring_t(e.typ)) + elif isinstance(e, ValidIf): + return ValidIf(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.value, inferring_map), auto_inferring_t(e.typ)) + elif isinstance(e, DoPrim): + return DoPrim(e.op, [auto_inferring_e(arg, inferring_map) for arg in e.args], e.consts, auto_inferring_t(e.typ)) + elif isinstance(e, UIntLiteral): + if e.width.width < get_binary_width(e.value): + return UIntLiteral(e.value, IntWidth(get_binary_width(e.value))) + else: + return e + elif isinstance(e, SIntLiteral): + if e.width.width < get_binary_width(e.value) + 1: + return SIntLiteral(e.value, IntWidth(get_binary_width(e.value))) + else: + return e + elif isinstance(e, Reference): + typ = inferring_map[e.name] if inferring_map[e.name] else auto_inferring_t(e.typ) + return Reference(e.name, typ) + elif isinstance(e, SubField): + expr = auto_inferring_e(e.expr, inferring_map) + typ = e.typ + for fx in expr.typ.fields: + if fx.name == e.name: + typ = fx.typ + return SubField(expr, e.name, typ) + elif isinstance(e, SubIndex): + expr = auto_inferring_e(e.expr, inferring_map) + return SubIndex(e.name, expr, e.value, expr.typ.typ) + elif isinstance(e, SubAccess): + expr = auto_inferring_e(e.expr, inferring_map) + index = auto_inferring_e(e.index, inferring_map) + return SubAccess(expr, index, expr.typ.typ) + else: + return e + + def auto_inferring_s(s: Statement, inferring_map: Dict[str, Type]) -> Statement: + if isinstance(s, Block): + stmts: List[Statement] = [] + for sx in s.stmts: + stmts.append(auto_inferring_s(sx, inferring_map)) + return Block(stmts) + elif isinstance(s, Conditionally): + return Conditionally(auto_inferring_e(s.pred, inferring_map), auto_inferring_s(s.conseq, inferring_map), auto_inferring_s(s.alt, inferring_map), s.info) + elif isinstance(s, DefRegister): + clock = auto_inferring_e(s.clock, inferring_map) + reset = auto_inferring_e(s.reset, inferring_map) + init = auto_inferring_e(s.init, inferring_map) + typ = auto_inferring_t(s.typ) + inferring_map[s.name] = typ + return DefRegister(s.name, typ, clock, reset, init, s.info) + elif isinstance(s, DefWire): + inferring_map[s.name] = auto_inferring_t(s.typ) + return s + elif isinstance(s, DefMemory): + inferring_map[s.name] = auto_inferring_t(s.memType) + return s + elif isinstance(s, DefNode): + value = auto_inferring_e(s.value, inferring_map) + inferring_map[s.name] = value.typ + return DefNode(s.name, value, s.info) + elif isinstance(s, DefMemPort): + clk = auto_inferring_e(s.clk, inferring_map) + index = auto_inferring_e(s.index, inferring_map) + return DefMemPort(s.name, s.mem, index, clk, s.rw, s.info) + elif isinstance(s, Connect): + return Connect(auto_inferring_e(s.loc, inferring_map), auto_inferring_e(s.expr, inferring_map), s.info, s.blocking, s.bidirection, s.mem) + else: + return s + + + def auto_inferring_m(m: DefModule, inferring_map: Dict[str, Type]) -> DefModule: + if isinstance(m, Module): + ports: List[Port] = [] + for p in m.ports: + inferring_map[p.name] = auto_inferring_t(p.typ) + ports.append(Port(p.name, p.direction, inferring_map[p.name], p.info)) + body = auto_inferring_s(m.body, inferring_map) + return Module(m.name, ports, body, m.typ, m.info) + else: + return m + + for m in c.modules: + inferring_map: Dict[str, Type] = {} + modules.append(auto_inferring_m(m, inferring_map)) + return Circuit(modules, c.main, c.info) \ No newline at end of file diff --git a/pyhcl/passes/infer_types.py b/pyhcl/passes/infer_types.py deleted file mode 100644 index 95dc2e1..0000000 --- a/pyhcl/passes/infer_types.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import List -from pyhcl.ir.low_ir import * -from pyhcl.ir.low_prim import * -from pyhcl.passes._pass import Pass -from pyhcl.passes.wir import * -from pyhcl.passes.utils import module_type, field_type, sub_type, mux_type, get_or_else - -class InferTypes(Pass): - def run(self, c: Circuit) -> Circuit: - mtyps: Dict[str, Type] = {} - for m in c.modules: - mtyps[m.name] = module_type(m) - - def infer_types_e(typs: Dict[str, Type], e: Expression) -> Expression: - if isinstance(e, Reference): - return Reference(e.name, get_or_else(e.name in typs.keys(), typs[e.name], UnknownType)) - elif isinstance(e, SubField): - return SubField(e.expr, e.name, field_type(e.expr.typ, e.name)) - elif isinstance(e, SubIndex): - return SubIndex(e.name, e.expr, e.value, sub_type(e.expr.typ)) - elif isinstance(e, SubAccess): - return SubAccess(e.expr, e.index, sub_type(e.expr.typ)) - elif isinstance(e, DoPrim): - return DoPrim(e.op, e.args, e.consts, e.typ) - elif isinstance(e, Mux): - return Mux(e.cond, e.tval, e.fval, mux_type(e.tval, e.fval)) - elif isinstance(e, ValidIf): - return ValidIf(e.cond, e.value, e.value.typ) - else: - return e - - def infer_types_s(typs: Dict[str, Type], s: Statement) -> Statement: - if isinstance(s, DefRegister): - typs[s.name] = s.typ - clock = infer_types_e(typs, s.clock) if hasattr(s, 'clock') and isinstance(s.clock, Expression) else None - reset = s.reset if hasattr(s, 'reset') and isinstance(s.init, Expression) else Reference('reset', ResetType()) - init = s.init if hasattr(s, 'init') and isinstance(s.init, Expression) else Reference(s.name, s.typ) - return DefRegister(s.name, s.typ, clock, reset, init, s.info) - elif isinstance(s, DefWire): - typs[s.name] = s.typ - return s - elif isinstance(s, DefNode): - value = infer_types_e(typs, s.value) if hasattr(s, 'value') and isinstance(s.value, Expression) else None - typs[s.name] = s.value.typ - return DefNode(s.name, value, s.info) - elif isinstance(s, DefMemory): - typs[s.name] = s.memType - return s - elif isinstance(s, DefInstance): - typs[s.name] = mtyps[s.module] - return s - else: - return s - - - def infer_types_p(typs: set, p: Port) -> Port: - typs[p.name] = p.typ - return p - - def infer_types(m: DefModule) -> DefModule: - if isinstance(m, ExtModule): - return m - types: Dict[str, Type] = {} - ports = None - stmts = None - if hasattr(m, 'ports') and isinstance(m.ports, list): - ports = list(map(lambda p: infer_types_p(types, p), m.ports)) - - if hasattr(m, 'body') and isinstance(m.body, Block): - if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list): - stmts = list(map(lambda s: infer_types_s(types, s), m.body.stmts)) - - return Module(m.name, ports, Block(stmts), m.typ, m.info) - - res = list(map(lambda m: infer_types(m), c.modules)) - return Circuit(res, c.main, c.info) \ No newline at end of file diff --git a/pyhcl/passes/infer_widths.py b/pyhcl/passes/infer_widths.py deleted file mode 100644 index faddaef..0000000 --- a/pyhcl/passes/infer_widths.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import List -from pyhcl.ir.low_ir import * -from pyhcl.ir.low_prim import * -from pyhcl.passes._pass import Pass -from pyhcl.passes.wir import * - -class InferWidths(Pass): - @staticmethod - def run(c: Circuit): - def infer_widths_t(widths: Dict[str, Width], name: str, t: Type): - if isinstance(t, BundleType): - bs = list(map(lambda f: infer_widths_t(widths, name, f.typ), t.fields)) - bs = [b for b in bs if b is not None].pop() - t, widths = bs[0], bs[1] - - elif isinstance(t, MemoryType): - return infer_widths_t(widths, name, t.typ) - - elif isinstance(t, VectorType): - return infer_widths_t(widths, name, t.typ) - - elif isinstance(t, UIntType): - if isinstance(t.width, UnknownWidth): - return UIntType(widths.values().pop()), widths.clear() - else: - widths[name] = t.width - elif isinstance(t, SIntType): - if isinstance(t.width, UnknownWidth): - return SIntType(widths.values().pop()), widths.clear() - else: - widths[name] = t.width - - elif isinstance(t, ClockType): - if isinstance(t.width, UnknownWidth): - return ClockType(widths.values().pop()), widths.clear() - else: - widths[name] = t.width - - elif isinstance(t, ResetType): - if isinstance(t.width, UnknownWidth): - return ResetType(widths.values().pop()), widths.clear() - else: - widths[name] = t.width - - elif isinstance(t, AsyncResetType): - if isinstance(t.width, UnknownWidth): - return AsyncResetType(widths.values().pop()), widths.clear() - else: - widths[name] = t.width - - return t, widths - - def infer_widths_e(widths: Dict[str, Width], e: Expression): - if isinstance(e, UIntLiteral): - return e, widths - elif isinstance(e, SIntLiteral): - return e, widths - elif isinstance(e, SubField): - t, widths = infer_widths_t(widths, e.name, e.typ) - return SubField(e.expr, e.name, t), widths - elif isinstance(e, SubAccess): - t, widths = infer_widths_t(widths, '', e.typ) - return SubAccess(e.expr, e.index, t), widths - elif isinstance(e, SubIndex): - t, widths = infer_widths_t(widths, e.name, e.typ) - return SubIndex(e.name, e.expr, e.value, t), widths - elif isinstance(e, DoPrim): - t, widths = infer_widths_t(widths, '', e.typ) - return DoPrim(e.op, e.args, e.consts, t), widths - elif isinstance(e, Mux): - t, widths = infer_widths_t(widths, '', e.typ) - return Mux(e.cond, e.tval, e.fval, t), widths - elif isinstance(e, ValidIf): - t, widths = infer_widths_t(widths, '', e.typ) - return Mux(e.cond, e.value, t), widths - else: - t, widths = infer_widths_t(widths, e.name, e.typ) - return Reference(e.name, t), widths - - - def infer_widths_s(widths: Dict[str, Width], s: Statement): - if isinstance(s, Connect): - expr, widths = infer_widths_e(widths, s.expr) - loc, widths = infer_widths_e(widths, s.loc) - return Connect(loc, expr, s.info) - return s - - - def infer_widths_m(m: DefModule) -> DefModule: - if isinstance(m, ExtModule): - return m - widths: Dict[str, Width] = {} - stmts = None - if hasattr(m, 'body') and isinstance(m.body, Block): - if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list): - stmts = list(map(lambda s: infer_widths_s(widths, s), m.body.stmts)) - - return Module(m.name, m.ports, Block(stmts), m.typ, m.info) - - return Circuit(list(map(lambda m: infer_widths_m(m), c.modules)), c.main, c.info) \ No newline at end of file