Merge pull request #4 from raybdzhou/master

new features
This commit is contained in:
Mosk0ng 2022-07-16 14:05:30 +08:00 committed by GitHub
commit 6c20d181b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 3840 additions and 345 deletions

10
.gitignore vendored
View File

@ -10,6 +10,7 @@ simulation/
# firrtl
.fir/
.v/
# obj
obj_dir/*
@ -122,10 +123,7 @@ venv.bak/
temp
# jar
*.jar
main.py
# docs
/docs/_site/
/docs/Gemfile.lock
.jekyll-cache

100
README.md
View File

@ -9,8 +9,11 @@ and dynamically typed objects.
The goal of PyHCL is providing a complete design and verification tool flow for heterogeneous computing systems flexibly using the same design methodology.
PyHCL is powered by [FIRRTL](https://github.com/freechipsproject/firrtl), an intermediate representation for digital circuit design. With the FIRRTL
compiler framework, PyHCL-generated circuits can be compiled to the widely-used HDL Verilog.
PyHCL is powered by [FIRRTL](https://github.com/freechipsproject/firrtl), an intermediate representation for digital circuit design.
PyHCL-generated circuits can be compiled to the widely-used HDL Verilog.
Attention: The back end of the compilation is highly experimental.
## Getting Started
@ -36,11 +39,11 @@ class FullAdder(Module):
io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin
```
#### Compiling To FIRRTL
#### Compiling To High FIRRTL
Compiling module by calling `compile_to_firrtl`:
Compiling module by calling `compile_to_highform`:
```python
Emitter.dump(Emitter.emit(FullAdder()), "FullAdder.fir")
Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
```
Will generate the following FIRRTL codes:
@ -49,55 +52,74 @@ circuit FullAdder :
module FullAdder :
input clock : Clock
input reset : UInt<1>
input FullAdder_io_a : UInt<1>
input FullAdder_io_b : UInt<1>
input FullAdder_io_cin : UInt<1>
output FullAdder_io_sum : UInt<1>
output FullAdder_io_cout : UInt<1>
output io : {flip a : UInt<1>, flip b : UInt<1>, flip cin : UInt<1>, s : UInt<1>, cout : UInt<1>}
node _T_0 = xor(FullAdder_io_a, FullAdder_io_b)
node _T_1 = xor(_T_0, FullAdder_io_cin)
FullAdder_io_sum <= _T_1
node _T_2 = and(FullAdder_io_a, FullAdder_io_b)
node _T_3 = and(FullAdder_io_b, FullAdder_io_cin)
node _T = xor(io.a, io.b)
node _T_1 = xor(_T, io.cin)
io.s <= _T_1
node _T_2 = and(io.a, io.b)
node _T_3 = and(io.a, io.cin)
node _T_4 = or(_T_2, _T_3)
node _T_5 = and(FullAdder_io_a, FullAdder_io_cin)
node _T_5 = and(io.b, io.cin)
node _T_6 = or(_T_4, _T_5)
FullAdder_io_cout <= _T_6
io.cout <= _T_6
```
#### Compiling To Lowered FIRRTL
Compiling module by calling `compile_to_lowform`:
```python
Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
```
Will generate the following FIRRTL codes:
```
circuit FullAdder :
module FullAdder :
input clock : Clock
input reset : UInt<1>
input io_a : UInt<1>
input io_b : UInt<1>
input io_cin : UInt<1>
output io_s : UInt<1>
output io_cout : UInt<1>
node _T = xor(io_a, io_b)
node _T_1 = xor(_T, io_cin)
io_s <= _T_1
node _T_2 = and(io_a, io_b)
node _T_3 = and(io_a, io_cin)
node _T_4 = or(_T_2, _T_3)
node _T_5 = and(io_b, io_cin)
node _T_6 = or(_T_4, _T_5)
io_cout <= _T_6
```
#### Compiling To Verilog
While FIRRTL is generated, PyHCL's job is complete. To further compile to Verilog, the [FIRRTL compiler framework](
https://github.com/freechipsproject/firrtl) is required:
Compiling module by calling `compile_to_verilog`:
```shell script
Emitter.dumpVerilog(Emitter.dump(Emitter.emit(FullAdder()), "FullAdder.fir"))
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")
```
Then `FullAdder.v` will be generated:
```verilog
module FullAdder(
input clock,
input reset,
input FullAdder_io_a,
input FullAdder_io_b,
input FullAdder_io_cin,
output FullAdder_io_sum,
output FullAdder_io_cout
input clock,
input reset,
input io_a,
input io_b,
input io_cin,
output io_s,
output io_cout
);
wire _T_0;
wire _T_2;
wire _T_3;
wire _T_4;
wire _T_5;
assign _T_0 = FullAdder_io_a ^ FullAdder_io_b;
assign _T_2 = FullAdder_io_a & FullAdder_io_b;
assign _T_3 = FullAdder_io_b & FullAdder_io_cin;
assign _T_4 = _T_2 | _T_3;
assign _T_5 = FullAdder_io_a & FullAdder_io_cin;
assign FullAdder_io_sum = _T_0 ^ FullAdder_io_cin;
assign FullAdder_io_cout = _T_4 | _T_5;
assign io_s = io_a ^ io_b ^ io_cin;
assign io_cout = io_a & io_b | io_a & io_cin | io_b & io_cin;
endmodule
```

View File

@ -15,4 +15,4 @@ class FullAdder(Module):
if __name__ == '__main__':
Emitter.dumpVerilog(Emitter.dump(Emitter.emit(FullAdder()), "FullAdder.fir"))
Emitter.dump(Emitter.emit(FullAdder(), True), "FullAdder.v")

73
main.py
View File

@ -1,61 +1,24 @@
from pyhcl import *
W = 8 # 位宽
class FullAdder(Module):
io = IO(
a=Input(Bool),
b=Input(Bool),
cin=Input(Bool),
sum=Output(Bool),
cout=Output(Bool),
)
# Generate the sum
io.sum <<= io.a ^ io.b ^ io.cin
def matrixMul(x: int, y: int, z: int):
"""
x*y × y*z 矩阵乘法电路
"""
class MatrixMul(Module):
io = IO(
a=Input(Vec(x, Vec(y, U.w(W)))),
b=Input(Vec(y, Vec(z, U.w(W)))),
o=Output(Vec(x, Vec(z, U.w(W)))),
)
for i, a_row in enumerate(io.a):
for j, b_col in enumerate(zip(*io.b)):
io.o[i][j] <<= Sum(a * b for a, b in zip(a_row, b_col))
return MatrixMul()
def bias(n):
return U.w(W)(n)
def weight(lst):
return VecInit(U.w(W)(i) for i in lst)
def neurons(w, b):
"""
参数权重向量 w偏移量 b
输出神经网络神经元电路 *暂无通过非线性传递函数
"""
class Unit(Module):
io = IO(
i=Input(Vec(len(w), U.w(W))),
o=Output(U.w(W))
)
m = matrixMul(1, len(w), 1).io
m.a <<= io.i
m.b <<= w
io.o <<= m.o[0][0] + b
return Unit()
def main():
# 得到权重向量为[3, 4, 5, 6, 7, 8, 9, 10]偏移量为14的神经元电路
n = neurons(weight([3, 4, 5, 6, 7, 8, 9, 10]), bias(14))
f = Emitter.dump(Emitter.emit(n), "neurons.fir")
#Emitter.dumpVerilog(f)
# Generate the carry
io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin
if __name__ == '__main__':
main()
# emit high firrtl
Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
# emit lowered firrtl
Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
# emit verilog
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")

View File

@ -10,3 +10,4 @@ from .funcs import CatVecL2H, CatVecH2L, CatBits, OneDimensionalization, Sum, De
from .memory import Mem
from .clockdomin import clockdomin
from .verifaction import doAssert, doAssume, doCover
from .stage import Form, HighForm, MidForm, LowForm, Verilog

View File

@ -7,20 +7,10 @@ from pyhcl.core._emit_context import EmitterContext
from pyhcl.dsl.module import Module
from pyhcl.ir import low_ir
from pyhcl.util.firrtltools import replacewithfirmod
from pyhcl.dsl.stage import Form, HighForm
class Emitter:
# 传入模块对象返回str---firrtl代码
@staticmethod
def emit(m: Module, toverilog=False) -> str:
circuit = Emitter.elaborate(m)
# 将Circuit对象转化为str
if(toverilog):
return circuit.verilog_serialize()
else:
return circuit.serialize() # firrtl代码
# 传入模块对象返回Circuit对象
@staticmethod
def elaborate(m: Module) -> low_ir.Circuit:
ec: EmitterContext = EmitterContext(m, {}, Counter())
@ -30,19 +20,39 @@ class Emitter:
DynamicContext.clearScope()
return circuit
# 传入firrtl代码和文件名将firrtl代码写入文件中并返回文件路径
@staticmethod
def emit(m: Module, f: Form = HighForm) -> str:
return f(Emitter.elaborate(m)).emit()
@staticmethod
def dump(s, filename) -> str:
if not os.path.exists('.fir'):
os.mkdir('.fir')
dir_name = "." + filename.split(".")[-1]
if not os.path.exists(dir_name):
os.mkdir(dir_name)
f = os.path.join('.fir', filename)
f = os.path.join(dir_name, filename)
with open(f, "w+") as fir_file:
fir_file.write(s)
return f
# 传入firrtl文件路径执行firrtl命令将firrtl代码编译为verilog代码
@staticmethod
def dumpVerilog(filename):
os.system('firrtl -i %s -o %s -X verilog' % (filename, filename))
def dumpVerilog(filename, use_jar=False):
if use_jar:
os.system('java -jar firrtl.jar -i %s -o %s -X verilog' % (filename, filename))
else:
os.system('firrtl -i %s -o %s -X verilog' % (filename, filename))
@staticmethod
def dumpMidForm(filename, use_jar=False):
if use_jar:
os.system('java -jar firrtl.jar -i %s -o %s -X middle' % (filename, filename))
else:
os.system('firrtl -i %s -o %s -X middle' % (filename, filename))
@staticmethod
def dumpLoweredForm(filename, use_jar):
if use_jar:
os.system('java -jar firrtl.jar -i %s -o %s -X low' % (filename, filename))
else:
os.system('firrtl -i %s -o %s -X low' % (filename, filename))

81
pyhcl/dsl/stage.py Normal file
View File

@ -0,0 +1,81 @@
from abc import ABC, abstractclassmethod
from dataclasses import dataclass
from pyhcl.ir import low_ir
from pyhcl.passes.check_form import CheckHighForm
from pyhcl.passes.check_types import CheckTypes
from pyhcl.passes.check_flows import CheckFlow
from pyhcl.passes.check_widths import CheckWidths
from pyhcl.passes.auto_inferring import AutoInferring
from pyhcl.passes.replace_subaccess import ReplaceSubaccess
from pyhcl.passes.expand_aggregate import ExpandAggregate
from pyhcl.passes.expand_whens import ExpandWhens
from pyhcl.passes.expand_memory import ExpandMemory
from pyhcl.passes.optimize import Optimize
from pyhcl.passes.verilog_optimize import VerilogOptimize
from pyhcl.passes.remove_access import RemoveAccess
from pyhcl.passes.expand_sequential import ExpandSequential
from pyhcl.passes.handle_instance import HandleInstance
from pyhcl.passes.utils import AutoName
class Form(ABC):
@abstractclassmethod
def emit(self) -> str:
...
@dataclass
class HighForm(Form):
c: low_ir.Circuit
def emit(self) -> str:
self.c = CheckHighForm(self.c).run()
self.c = AutoInferring().run(self.c)
self.c = CheckTypes().run(self.c)
self.c = CheckFlow().run(self.c)
self.c = CheckWidths().run(self.c)
return self.c.serialize()
@dataclass
class MidForm(Form):
def emit(self) -> str:
...
@dataclass
class LowForm(Form):
c: low_ir.Circuit
def emit(self) -> str:
AutoName()
self.c = CheckHighForm(self.c).run()
self.c = AutoInferring().run(self.c)
self.c = CheckTypes().run(self.c)
self.c = CheckFlow().run(self.c)
self.c = CheckWidths().run(self.c)
self.c = ExpandMemory().run(self.c)
self.c = ReplaceSubaccess().run(self.c)
self.c = ExpandAggregate().run(self.c)
self.c = RemoveAccess().run(self.c)
self.c = ExpandWhens().run(self.c)
self.c = HandleInstance().run(self.c)
self.c = Optimize().run(self.c)
return self.c.serialize()
@dataclass
class Verilog(Form):
c: low_ir.Circuit
def emit(self) -> str:
AutoName()
self.c = CheckHighForm(self.c).run()
self.c = AutoInferring().run(self.c)
self.c = CheckTypes().run(self.c)
self.c = CheckFlow().run(self.c)
self.c = CheckWidths().run(self.c)
self.c = ExpandAggregate().run(self.c)
self.c = ReplaceSubaccess().run(self.c)
self.c = RemoveAccess().run(self.c)
self.c = VerilogOptimize().run(self.c)
self.c = ExpandSequential().run(self.c)
self.c = HandleInstance().run(self.c)
self.c = Optimize().run(self.c)
return self.c.verilog_serialize()

View File

@ -84,7 +84,7 @@ class VecInit(Node):
# Connect Elements
for i, node in enumerate(self.lst):
for idx, elem in self.subIdxs(low_ir.SubIndex(ref, i, typ.typ), node, ctx):
for idx, elem in self.subIdxs(low_ir.SubIndex('', ref, i, typ.typ), node, ctx):
con = low_ir.Connect(idx, elem)
ctx.appendFinalStatement(con, self.scopeId)
@ -93,7 +93,7 @@ class VecInit(Node):
def subIdxs(self, idx, node, ctx):
if isinstance(node, VecInit):
return [(nIdx, elem) for i, n in enumerate(node.lst)
for nIdx, elem in node.subIdxs(low_ir.SubIndex(idx, i, node.typ.mapToIR(ctx)), n, ctx)]
for nIdx, elem in node.subIdxs(low_ir.SubIndex('', idx, i, node.typ.mapToIR(ctx)), n, ctx)]
else:
return [(idx, node.mapToIR(ctx))]

View File

@ -2,13 +2,11 @@ from __future__ import annotations
from abc import ABC
from dataclasses import dataclass, field
from typing import List, Optional
from typing import Dict, List, Optional
from pyhcl.ir.low_node import FirrtlNode
from pyhcl.ir.low_prim import PrimOp, Bits
from pyhcl.ir.utils import backspace, indent, deleblankline, backspace1
from pyhcl.ir.low_prim import PrimOp, Bits, Cat
from pyhcl.ir.utils import backspace, indent, deleblankline, backspace1, get_binary_width, TransformException, DAG
class Info(FirrtlNode, ABC):
"""INFOs"""
def serialize(self) -> str:
@ -58,7 +56,11 @@ class Expression(FirrtlNode, ABC):
class Type(FirrtlNode, ABC):
"""TYPEs"""
...
def map_type(self, typ):
...
def map_width(self, typ):
...
@dataclass(frozen=True, init=False)
@ -69,7 +71,6 @@ class UnknownType(Type):
def verilog_serialize(self) -> str:
return self.serialize()
@dataclass(frozen=True)
class Reference(Expression):
name: str
@ -92,11 +93,12 @@ class SubField(Expression):
return f"{self.expr.serialize()}.{self.name}"
def verilog_serialize(self) -> str:
return f"{self.expr.verilog_serialize()}_{self.name}"
return self.serialize()
@dataclass(frozen=True)
class SubIndex(Expression):
name: str
expr: Expression
value: int
typ: Type
@ -105,7 +107,7 @@ class SubIndex(Expression):
return f"{self.expr.serialize()}[{self.value}]"
def verilog_serialize(self) -> str:
return f"{self.expr.verilog_serialize()}[{self.value}]"
return self.serialize()
@dataclass(frozen=True)
@ -117,9 +119,8 @@ class SubAccess(Expression):
def serialize(self) -> str:
return f"{self.expr.serialize()}[{self.index.serialize()}]"
def verilog_serialize(self) -> str:
return f"{self.expr.verilog_serialize()}[{self.index.verilog_serialize()}]"
def verilog_serialize(self):
return self.serialize()
@dataclass(frozen=True)
class Mux(Expression):
@ -134,6 +135,18 @@ class Mux(Expression):
def verilog_serialize(self) -> str:
return f"{self.cond.verilog_serialize()} ? {self.tval.verilog_serialize()} : {self.fval.verilog_serialize()}"
@dataclass(frozen=True)
class ValidIf(Expression):
cond: Expression
value: Expression
typ: Type
def serialize(self) -> str:
return f"validif({self.cond.serialize()}, {self.value.serialize()})"
def verilog_serialize(self) -> str:
return f"{self.cond.verilog_serialize()} ? {self.value.verilog_serialize()} : Z"
@dataclass(frozen=True)
class DoPrim(Expression):
@ -148,7 +161,18 @@ class DoPrim(Expression):
def verilog_serialize(self) -> str:
sl: List[str] = [arg.verilog_serialize() for arg in self.args] + [repr(con) for con in self.consts]
return f'{self.op.verilog_serialize().join(sl)}'
if isinstance(self.op, Cat):
return f'{{{", ".join(sl)}}}'
elif isinstance(self.op, Bits):
arg = self.args[0]
msb, lsb = self.consts[0], self.consts[1]
msb_lsb = f'{msb}: {lsb}' if msb != lsb else f'{msb}'
return f'{arg.verilog_serialize()}[{msb_lsb}]'
if len(sl) > 1:
return f'{self.op.verilog_serialize().join(sl)}'
else:
return f'{self.op.verilog_serialize()}{sl.pop()}'
@dataclass(frozen=True)
@ -175,10 +199,10 @@ class UnknownWidth(Width):
width: int = None
def serialize(self) -> str:
return ""
return ''
def verilog_serialize(self) -> str:
return ""
return ''
@dataclass(frozen=True, init=False)
@ -291,8 +315,7 @@ class Flip(Orientation):
return 'flip '
def verilog_serialize(self) -> str:
# no "flip" in verilog, bug here
return 'flip'
return ''
@dataclass(frozen=True)
@ -306,7 +329,7 @@ class Field(FirrtlNode):
return f'{self.flip.serialize()}{self.name} : {self.typ.serialize()}'
def verilog_serialize(self) -> str:
return f'{self.flip.verilog_serialize()}{self.typ.verilog_serialize()}\t${self.name}'
return ''
@dataclass(frozen=True)
@ -317,14 +340,7 @@ class BundleType(AggregateType):
return '{' + ', '.join([f.serialize() for f in self.fields]) + '}'
def verilog_serialize(self) -> list:
field_list = []
for f in self.fields:
if type(f.typ.verilog_serialize()) is list:
for t in f.typ.verilog_serialize():
field_list.append(f'${f.name}{t}')
else:
field_list.append(f.verilog_serialize())
return field_list
return ''
@dataclass(frozen=True)
@ -335,12 +351,13 @@ class VectorType(AggregateType):
def serialize(self) -> str:
return f'{self.typ.serialize()}[{self.size}]'
def verilog_serialize(self) -> list:
return [f'_{v}' for v in range(self.size)]
def verilog_serialize(self):
return ''
def irWithIndex(self, index):
if isinstance(index, int):
return lambda _: {"ir": SubIndex(_, index, self.typ)}
return lambda _: {"ir": SubIndex('', _, index, self.typ)}
else:
return lambda _: {"ir": SubAccess(_, index, self.typ)}
@ -354,7 +371,7 @@ class MemoryType(AggregateType):
return f'{self.typ.serialize()}[{self.size}]'
def verilog_serialize(self) -> str:
return [f'_{v}' for v in range(self.size)]
return f'reg {self.typ.verilog_serialize()} m [0:{self.size-1}];'
def irWithIndex(self, index):
return lambda _: {"ir": lambda name, mem, clk, rw: DefMemPort(name, mem, index, clk, rw), "inPort": True}
@ -368,7 +385,7 @@ class ClockType(GroundType):
return 'Clock'
def verilog_serialize(self) -> str:
return ""
return ''
@dataclass(frozen=True, init=False)
@ -379,8 +396,7 @@ class ResetType(GroundType):
return "UInt<1>"
def verilog_serialize(self) -> str:
# Todo
return "Reset"
return ''
@dataclass(frozen=True, init=False)
@ -391,8 +407,7 @@ class AsyncResetType(GroundType):
return "AsyncReset"
def verilog_serialize(self) -> str:
# Todo
return "AsyncReset"
return ''
class Direction(FirrtlNode, ABC):
@ -434,20 +449,7 @@ class Port(FirrtlNode):
return f'{self.direction.serialize()} {self.name} : {self.typ.serialize()}{self.info.serialize()}'
def verilog_serialize(self) -> str:
if type(self.typ.verilog_serialize()) is str:
return f'{self.direction.verilog_serialize()}\t{self.typ.verilog_serialize()}\t{self.name},\n'
else:
portdeclares = ''
seq = self.typ.verilog_serialize()
for s in seq:
ns = s.replace('$', f'{self.name}_')
if "flip" in ns:
ns = ns.replace('flip', "")
portdeclares += f'{self.direction.verilog_serialize(True)}\t{ns},\n'
else:
portdeclares += f'{self.direction.verilog_serialize()}\t{ns},\n'
return portdeclares
return f'{self.direction.verilog_serialize()}\t{self.typ.verilog_serialize()}\t{self.name},\t{self.info.verilog_serialize()}'
class Statement(FirrtlNode, ABC):
...
@ -459,7 +461,7 @@ class EmptyStmt(Statement):
return 'skip'
def verilog_serialize(self) -> str:
return '// skip'
return ''
@dataclass(frozen=True)
@ -472,7 +474,7 @@ class DefWire(Statement):
return f'wire {self.name} : {self.typ.serialize()}{self.info.serialize()}'
def verilog_serialize(self) -> str:
return f'wire\t{self.typ.verilog_serialize()}\t{self.name}{self.info.verilog_serialize()};'
return f'wire\t{self.typ.verilog_serialize()}\t{self.name};\t{self.info.verilog_serialize()}'
@dataclass(frozen=True)
@ -485,15 +487,16 @@ class DefRegister(Statement):
info: Info = NoInfo()
def serialize(self) -> str:
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
if self.init is not None else ""
i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
if self.reset is not None else ""
if self.init:
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
if self.init is not None else ""
else:
i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
if self.reset is not None else ""
return f'reg {self.name} : {self.typ.serialize()}, {self.clock.serialize()}{i}{self.info.serialize()}'
def verilog_serialize(self) -> str:
return f'reg {self.typ.verilog_serialize()}\t{self.name}{self.info.verilog_serialize()};'
return f'reg\t{self.typ.verilog_serialize()}\t{self.name};\t{self.info.verilog_serialize()}'
@dataclass(frozen=True)
class DefInstance(Statement):
@ -506,36 +509,40 @@ class DefInstance(Statement):
return f'inst {self.name} of {self.module}{self.info.serialize()}'
def verilog_serialize(self) -> str:
instdeclares = ''
instdeclares: List[str] = []
portdeclares: List[str] = []
for p in self.ports:
instdeclares += f'\n.{p.name}({self.name}_{p.name}),'
return f'{self.module} {self.name} ({instdeclares}\n);'
portdeclares.append(f'wire\t{p.typ.verilog_serialize()}\t{self.name}_{p.name};')
instdeclares.append(indent(f'\n.{p.name}({self.name}_{p.name}),'))
port_decs = '\n'.join(portdeclares)
return f"{port_decs}\n{self.module}\t{self.name}(\t{self.info.verilog_serialize()}{''.join(instdeclares)});"
@dataclass(frozen=True)
class WDefMemory(Statement):
name: str
memType: MemoryType
dataType: Type
depth: int
writeLatency: int
readLatency: int
readers: List[str]
writers: List[str]
readUnderWrite: Optional[str] = None
info: Info = NoInfo()
# @dataclass(frozen=True)
# class DefMemory(Statement):
# name: str
# dataType: Type
# depth: int
# writeLatency: int
# readLatency: int
# readers: List[str]
# writers: List[str]
# readWriters: List[str]
# readUnderWrite: Optional[str] = None
# info: Info = NoInfo()
#
# def serialize(self) -> str:
# lst = [f"\ndata-type => {self.dataType.serialize()}",
# f"depth => {self.depth}",
# f"read-latency => {self.readLatency}",
# f"write-latency => {self.writeLatency}"] + \
# [f"reader => {r}" for r in self.readers] + \
# [f"writer => {w}" for w in self.writers] + \
# [f"readwriter => {rw}" for rw in self.readWriters] + \
# ['read-under-write => undefined']
# s = indent('\n'.join(lst))
# return f'mem {self.name} :{self.info.serialize()}{s}'
def serialize(self) -> str:
lst = [f"\ndata-type => {self.dataType.serialize()}",
f"depth => {self.depth}",
f"read-latency => {self.readLatency}",
f"write-latency => {self.writeLatency}"] + \
[f"reader => {r}" for r in self.readers] + \
[f"writer => {w}" for w in self.writers] + \
['read-under-write => undefined']
s = indent('\n'.join(lst))
return f'mem {self.name} :{self.info.serialize()}{s}'
def verilog_serialize(self) -> str:
return self.memType.verilog_serialize()
@dataclass(frozen=True)
class DefMemory(Statement):
@ -547,10 +554,7 @@ class DefMemory(Statement):
return f'cmem {self.name} : {self.memType.serialize()}{self.info.serialize()}'
def verilog_serialize(self) -> str:
memorydeclares = ''
memorydeclares += f'{self.memType.verilog_serialize()}\n'
memorydeclares += f'{self.info.verilog_serialize()}\n'
return memorydeclares
return f'{self.memType.verilog_serialize()}\t{self.info.verilog_serialize()}'
@dataclass(frozen=True)
@ -563,7 +567,7 @@ class DefNode(Statement):
return f'node {self.name} = {self.value.serialize()}{self.info.serialize()}'
def verilog_serialize(self) -> str:
return f'wire {self.value.typ.verilog_serialize()} {self.name} = {self.value.verilog_serialize()}{self.info.verilog_serialize()};'
return f'wire\t{self.value.typ.verilog_serialize()}\t{self.name} = {self.value.verilog_serialize()};\t{self.info.verilog_serialize()}'
@dataclass(frozen=True)
@ -582,12 +586,15 @@ class DefMemPort(Statement):
def verilog_serialize(self) -> str:
memportdeclares = ''
memportdeclares += indent(f'{self.mem.verilog_serialize()}')
# TODO
memportdeclares += f'wire {self.mem.typ.typ.verilog_serialize()} {self.mem.verilog_serialize()}_{self.name}_data;\n'
memportdeclares += f'wire [{get_binary_width(self.mem.typ.size)-1}:0] {self.mem.verilog_serialize()}_{self.name}_addr;\n'
memportdeclares += f'wire {self.mem.verilog_serialize()}_{self.name}_en;\n'
memportdeclares += f'assign {self.mem.verilog_serialize()}_{self.name}_addr = {get_binary_width(self.mem.typ.size)}\'h{self.index.value};\n'
memportdeclares += f'assign {self.mem.verilog_serialize()}_{self.name}_en = 1\'h1;\n'
if self.rw is False:
memportdeclares += f'wire {self.mem.verilog_serialize()}_{self.name}_mask;\n'
return memportdeclares
# stmt pass will change the conseq and alt
@dataclass(frozen=False)
class Conditionally(Statement):
pred: Expression
@ -602,132 +609,40 @@ class Conditionally(Statement):
def verilog_serialize(self) -> str:
s = indent(f'\n{self.conseq.verilog_serialize()}') + \
('' if self.alt == EmptyStmt() else '\nelse' + indent(f'\n{self.alt.verilog_serialize()}'))
return f'if ({self.pred.verilog_serialize()}) {self.info.verilog_serialize()}{s}'
class ValTracer:
...
class RegTracer(ValTracer):
def __init__(self, reg: DefRegister):
self.stmt = reg
self.clock = reg.clock
self.reset = reg.reset
self.conds = {}
def add_cond(self, signal: str, action):
if signal not in self.conds:
self.conds[signal] = [action]
else:
self.conds[signal].append(action)
def gen_body(self):
stmts = []
final_map = {}
for k, v in self.conds.items():
if(k == "default"):
stmts.append(Block(self.conds[k]))
else:
stmts.append(Conditionally(Reference(k, UIntType(IntWidth(1))), Block(self.conds[k]), EmptyStmt()))
self.body = Block(stmts)
def gen_always_block(self):
self.gen_body()
res = f'always @(posedge {self.clock.verilog_serialize()}) begin\n' + \
deleblankline(indent(f'\n{self.body.verilog_serialize()}')) + \
f'\nend'
return res
class PassManager:
def __init__(self, target_block):
self.block = target_block
self.reg_map = {}
self.define_pass()
def renew(self):
return self.stmts_pass(self.block)
# find out all reg definition and its clock and reset infos
def define_pass(self):
reg_tracers = [RegTracer(stmt) for stmt in self.block.stmts if type(stmt) == DefRegister]
self.reg_map = {x.stmt.name: x for x in reg_tracers}
# nest
def stmts_pass(self, block, signal: str = "default") -> Block:
if type(block) == EmptyStmt:
return EmptyStmt()
for stmt in block.stmts:
if type(stmt) == Connect and stmt.loc.name in self.reg_map:
stmt.blocking = False
self.reg_map[stmt.loc.name].add_cond(signal, stmt)
elif type(stmt) == Conditionally:
stmt.conseq = self.stmts_pass(stmt.conseq, stmt.pred.verilog_serialize())
stmt.alt = self.stmts_pass(stmt.alt, "!" + stmt.pred.verilog_serialize())
else:
pass
stmts = [stmt for stmt in block.stmts if self.pass_check(stmt)]
return Block(stmts) if stmts else EmptyStmt()
def pass_check(self, stmt):
return type(stmt) != Conditionally and type(stmt) != Connect \
or type(stmt) == Connect and stmt.loc.name not in self.reg_map \
or type(stmt) == Conditionally and type(stmt.conseq) != EmptyStmt
def gen_all_always_block(self):
res = ""
for reg, tracer in self.reg_map.items():
res += f'\n// handle register {reg}'
res += f'\n{tracer.gen_always_block()}'
return res
s = indent(f"\n{self.conseq.verilog_serialize()}") + "\nend" + \
("" if self.alt == EmptyStmt() else "\nelse begin" + indent(f"\n{self.alt.verilog_serialize()}") + "\nend")
return f"if ({self.pred.verilog_serialize()}) begin\t{self.info.verilog_serialize()}{s}"
@dataclass(frozen=True)
class Block(Statement):
stmts: List[Statement]
"""
def serialize(self) -> str:
return '\n'.join([stmt.serialize() for stmt in self.stmts]) if self.stmts else ""
"""
def auto_gen_node(self, stmt):
return isinstance(stmt, DefNode) and stmt.name.startswith("_T")
# use less nodes
def serialize(self) -> str:
if not self.stmts:
return ""
node_exp_map = {stmt.name: stmt for stmt in self.stmts if self.auto_gen_node(stmt)}
# replace all reference in node_exp_map
for k, v in node_exp_map.items():
if isinstance(v.value, DoPrim):
args = v.value.args
cnt = 0
for arg in args:
if isinstance(arg, Reference) and arg.name in node_exp_map:
node_exp_map[k].value.args[cnt] = node_exp_map[arg.name].value
cnt += 1
# replace all reference in connect
for stmt in self.stmts:
if isinstance(stmt, Connect) and isinstance(stmt.expr, Reference) and stmt.expr.name in node_exp_map:
stmt.expr = node_exp_map[stmt.expr.name].value
return '\n'.join([stmt.serialize() for stmt in self.stmts if not self.auto_gen_node(stmt)]) if self.stmts else ""
def verilog_serialize(self) -> str:
manager = PassManager(self)
new_blocks = manager.renew()
always_blocks = manager.gen_all_always_block()
return '\n'.join([stmt.verilog_serialize() for stmt in self.stmts])
return '\n'.join([stmt.verilog_serialize() for stmt in new_blocks.stmts]) + f'\n{always_blocks}' if self.stmts else ""
@dataclass(frozen=True)
class AlwaysBlock(Statement):
stmts: List[Statement]
clk: Expression = None
def serialize(self) -> str:
pass
def verilog_serialize(self) -> str:
cat_table: List[str] = []
for stmt in self.stmts:
cat_table.append(stmt.verilog_serialize())
if self.clk is None:
declares = '\n' + "\n".join(cat_table)
return deleblankline(f"always @(posedge clock) begin\n{indent(declares)}"+ "\nend")
else:
declares = '\n' + "\n".join(cat_table)
return deleblankline(f"always @(posedge {self.clk.verilog_serialize()}) begin\n{indent(declares)}" + "\nend")
# pass will change the "blocking" feature of Connect Stmt
@dataclass(frozen=False)
class Connect(Statement):
loc: Expression
@ -735,17 +650,20 @@ class Connect(Statement):
info: Info = NoInfo()
blocking: bool = True
bidirection: bool = False
mem: Dict = field(default_factory=dict)
def serialize(self) -> str:
if not self.bidirection:
return f'{self.info.serialize()}\n{self.loc.serialize()} <= {self.expr.serialize()}'
return f'{self.info.serialize()}{self.loc.serialize()} <= {self.expr.serialize()}'
else:
return f'{self.info.serialize()}\n{self.loc.serialize()} <= {self.expr.serialize()}\n' + \
f'{self.info.serialize()}\n{self.expr.serialize()} <= {self.loc.serialize()}'
def verilog_serialize(self) -> str:
op = "=" if self.blocking else "<="
return f'assign {self.loc.verilog_serialize()} {op} {self.expr.verilog_serialize()}{self.info.verilog_serialize()};'
if self.blocking is False:
return f'{self.loc.verilog_serialize()} {op} {self.expr.verilog_serialize()};\t{self.info.verilog_serialize()}'
return f'assign\t{self.loc.verilog_serialize()} {op} {self.expr.verilog_serialize()};\t{self.info.verilog_serialize()}'
# Verification
@ -802,8 +720,10 @@ class DefModule(FirrtlNode, ABC):
return f'{typ} {self.name} :{self.info.serialize()}{moduledeclares}\n'
def verilog_serializeHeader(self, typ: str) -> str:
moduledeclares = ' ' + indent(''.join([f'{p.verilog_serialize()}' for p in self.ports]))
return f'{typ} {self.name}(\n{deleblankline(moduledeclares)[:-1]}\n);\n'
port_declares: List[str] = []
for p in self.ports:
port_declares.append(indent("\n"+ p.verilog_serialize()))
return f"{typ} {self.name}(\t{self.info.verilog_serialize()}{''.join(port_declares)}\n);\n"
@dataclass(frozen=True)
@ -834,7 +754,7 @@ class ExtModule(DefModule):
return f'{self.serializeHeader("extmodule")}{s}'
def verilog_serialize(self) -> str:
return f'{self.verilog_serializeHeader("module")}endmodule\n'
return ""
@dataclass(frozen=True)
@ -844,9 +764,85 @@ class Circuit(FirrtlNode):
info: Info = NoInfo()
def serialize(self) -> str:
CheckCombLoop()
ms = '\n'.join([indent(f'\n{m.serialize()}') for m in self.modules])
return f'circuit {self.main} :{self.info.serialize()}{ms}\n'
def verilog_serialize(self) -> str:
self.requires()
ms = ''.join([f'{m.verilog_serialize()}\n' for m in self.modules])
return ms
def requires(self):
CheckCombLoop()
class CheckCombLoop:
connect_graph = DAG()
reg_map = {}
@staticmethod
def run(stmt: Statement):
def check_comb_loop_e(u: Expression, v: Expression):
if isinstance(u, (Reference, SubField, SubIndex, SubAccess)):
ux, vx = u.serialize(), v.serialize()
if ux in CheckCombLoop.reg_map or vx in CheckCombLoop.reg_map:
return
try:
CheckCombLoop.connect_graph.add_node_if_not_exists(vx)
CheckCombLoop.connect_graph.add_node_if_not_exists(ux)
CheckCombLoop.connect_graph.add_edge(ux, vx)
except TransformException as e:
raise e
elif isinstance(u, Mux):
check_comb_loop_e(u.tval, v)
check_comb_loop_e(u.fval, v)
elif isinstance(u, ValidIf):
check_comb_loop_e(u.value, v)
elif isinstance(u, DoPrim):
for arg in u.args:
check_comb_loop_e(arg, v)
else:
...
def check_comb_loop_s(s: Statement):
if isinstance(s, Connect):
if isinstance(s.loc, (Reference, SubField, SubIndex, SubAccess)):
check_comb_loop_e(s.expr, s.loc)
elif isinstance(s, DefNode):
check_comb_loop_e(s.value, Reference(s.name, s.value.typ))
elif isinstance(s, Conditionally):
check_comb_loop_s(s.conseq)
check_comb_loop_s(s.alt)
elif isinstance(s, EmptyStmt):
...
elif isinstance(s, Block):
for stmt in s.stmts:
check_comb_loop_s(stmt)
def get_reg_map(s: Statement):
if isinstance(s, Block):
for sx in s.stmts:
if isinstance(sx, DefRegister):
CheckCombLoop.reg_map[sx.name] = sx
elif isinstance(sx, DefMemPort):
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_data"] = DefWire(f"{sx.mem.name}_{sx.name}_data",
UIntType(IntWidth(sx.mem.typ.size)))
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_addr"] = DefWire(f"{sx.mem.name}_{sx.name}_addr",
UIntType(IntWidth(get_binary_width(sx.mem.typ.size))))
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_en"] = DefWire(f"{sx.mem.name}_{sx.name}_en",
UIntType(IntWidth(1)))
if sx.rw is False:
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_mask"] = DefWire(f"{sx.mem.name}_{sx.name}_mask",
UIntType(IntWidth(1)))
else:
...
elif isinstance(s, EmptyStmt):
...
elif isinstance(s, Conditionally):
get_reg_map(s.conseq)
get_reg_map(s.alt)
get_reg_map(stmt)
check_comb_loop_s(stmt)
return stmt

View File

@ -220,7 +220,7 @@ class Not(PrimOp):
return 'not'
def verilog_op(self):
return " ~ "
return "!"
# Bitwise And
@ -280,6 +280,9 @@ class Cat(PrimOp):
def __repr__(self):
return 'cat'
def verilog_op(self):
return self.__repr__()
# Bit Extraction
@dataclass(frozen=True, init=False)
@ -287,6 +290,9 @@ class Bits(PrimOp):
def __repr__(self):
return 'bits'
def verilog_op(self):
return self.__repr__()
# Head
@dataclass(frozen=True, init=False)

View File

@ -1,5 +1,9 @@
def indent(string: str) -> str:
return string.replace('\n', '\n ')
import math
from collections import OrderedDict, defaultdict
from copy import copy, deepcopy
def indent(string: str, space: int = 1) -> str:
return string.replace('\n', '\n' + '\t' * space)
def backspace(string: str) -> str:
@ -34,3 +38,130 @@ def auto_connect(ma, mb):
io_left <<= io_right
else:
io_right <<= io_left
def get_binary_width(target):
width = 1
while target / 2 >= 1:
width += 1
target = math.floor(target / 2)
return width
class TransformException(Exception):
def __init__(self, message: str):
self.message = message
def __str__(self):
return self.message
class DAG:
""" Directed acyclic graph implementation."""
def __init__(self):
self.graph = OrderedDict()
def add_node(self, name: str, graph = None):
if graph is None:
graph = self.graph
if name in graph:
...
else:
graph[name] = set()
def add_node_if_not_exists(self, name: str, graph = None):
try:
self.add_node(name, graph = graph)
except TransformException as e:
raise e
def delete_node(self, name: str, graph = None):
if graph is None:
graph = self.graph
if name not in graph:
raise TransformException(f'node {name} is not exists.')
graph.pop(name)
for _, edges in graph.items():
if name in edges:
edges.remove(name)
def delete_node_if_exists(self, name: str, graph = None):
try:
self.delete_node(name, graph = graph)
except TransformException as e:
raise e
def add_edge(self, ind_node, dep_node, graph = None):
if graph is None:
graph = self.graph
if ind_node not in graph or dep_node not in graph:
raise TransformException(f'nodes do not exist in graph.')
test_graph = deepcopy(graph)
test_graph[ind_node].add(dep_node)
is_valid, msg = self.validate(test_graph)
if is_valid:
graph[ind_node].add(dep_node)
else:
raise TransformException(f'Loop do exist in graph: {msg}')
def delete_edge(self, ind_node, dep_node, graph = None):
if graph is None:
graph = self.graph
if dep_node not in graph.get(ind_node, []):
raise TransformException(f'This edge does not exist in graph')
graph[ind_node].remove(dep_node)
def ind_nodes(self, graph = None):
if graph == None:
graph = self.graph
dep_nodes = set(
node for deps in graph.values() for node in deps
)
return [node for node in graph.keys() if node not in dep_nodes]
def topological_sort(self, graph = None):
if graph is None:
graph = self.graph
result = []
in_degree = defaultdict(lambda: 0)
for u in graph:
for v in graph[u]:
in_degree[v] += 1
ready = [node for node in graph if not in_degree[node]]
while ready:
u = ready.pop()
result.append(u)
for v in graph[u]:
in_degree[v] -= 1
if in_degree[v] == 0:
ready.append(v)
if len(result) == len(graph):
return result
else:
raise TransformException(f'graph is not acyclic.')
def validate(self, graph = None):
if graph is None:
graph = self.graph
if len(self.ind_nodes(graph)) == 0:
return False, 'no independent nodes detected.'
try:
self.topological_sort(graph)
except TransformException:
return False, 'graph is not acyclic.'
return True, 'valid'
def visit_graph(self, graph = None):
visited = []
if graph is None:
graph = self.graph
for v in graph:
for u in graph[v]:
visited.append(f'{v} -> {u}')
return visited
def size(self):
return len(self.graph)

0
pyhcl/passes/__init__.py Normal file
View File

37
pyhcl/passes/_pass.py Normal file
View File

@ -0,0 +1,37 @@
from typing import List
class Pass:
...
#Error handling
class PassException(Exception):
def __init__(self, message: str):
self.message = message
def __str__(self):
return self.message
class PassExceptions(Exception):
def __init__(self, exceptions: List[PassException]):
self.message = '\n'.join([str(exception) for exception in exceptions])
def __str__(self):
return '\n' + self.message
class Error:
def __init__(self):
self.errors: List[PassException] = []
def append(self, pe: PassException):
self.errors.append(pe)
def trigger(self):
if len(self.errors) == 0:
return
elif len(self.errors) == 1:
raise self.errors.pop()
else:
self.append(f'{len(self.errors)} errors detected!')
raise PassExceptions(self.errors)

View File

@ -0,0 +1,128 @@
from dataclasses import dataclass
from typing import Dict, List
from pyhcl.ir.low_ir import *
@dataclass
class AutoInferring:
max_width: int = 0
def run(self, c: Circuit):
modules: List[Module] = []
def auto_inferring_t(t: Type) -> Type:
if isinstance(t, UIntType):
if t.width.width == 0:
return UIntType(IntWidth(self.max_width))
else:
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
return t
elif isinstance(t, SIntType):
if t.width.width == 0:
return SIntType(IntWidth(self.max_width))
else:
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
return t
elif isinstance(t, (ClockType, ResetType, AsyncResetType)):
return t
elif isinstance(t, VectorType):
return VectorType(auto_inferring_t(t.typ), t.size)
elif isinstance(t, MemoryType):
return MemoryType(auto_inferring_t(t.typ), t.size)
elif isinstance(t, BundleType):
return BundleType([Field(fx.name, fx.flip, auto_inferring_t(fx.typ)) for fx in t.fields])
else:
return t
def auto_inferring_e(e: Expression, inferring_map: Dict[str, Type]) -> Expression:
if isinstance(e, Mux):
return Mux(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.tval, inferring_map),
auto_inferring_e(e.fval, inferring_map), auto_inferring_t(e.typ))
elif isinstance(e, ValidIf):
return ValidIf(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.value, inferring_map), auto_inferring_t(e.typ))
elif isinstance(e, DoPrim):
return DoPrim(e.op, [auto_inferring_e(arg, inferring_map) for arg in e.args], e.consts, auto_inferring_t(e.typ))
elif isinstance(e, UIntLiteral):
if e.width.width < get_binary_width(e.value):
return UIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
else:
return e
elif isinstance(e, SIntLiteral):
if e.width.width < get_binary_width(e.value) + 1:
return SIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
else:
return e
elif isinstance(e, Reference):
typ = inferring_map[e.name] if not isinstance(inferring_map[e.name], UnknownType) else auto_inferring_t(e.typ)
return Reference(e.name, typ)
elif isinstance(e, SubField):
expr = auto_inferring_e(e.expr, inferring_map)
typ = e.typ
for fx in expr.typ.fields:
if fx.name == e.name:
typ = fx.typ
return SubField(expr, e.name, typ)
elif isinstance(e, SubIndex):
expr = auto_inferring_e(e.expr, inferring_map)
return SubIndex(e.name, expr, e.value, expr.typ.typ)
elif isinstance(e, SubAccess):
expr = auto_inferring_e(e.expr, inferring_map)
index = auto_inferring_e(e.index, inferring_map)
return SubAccess(expr, index, expr.typ.typ)
else:
return e
def auto_inferring_s(s: Statement, inferring_map: Dict[str, Type]) -> Statement:
if isinstance(s, Block):
stmts: List[Statement] = []
for sx in s.stmts:
stmts.append(auto_inferring_s(sx, inferring_map))
return Block(stmts)
elif isinstance(s, Conditionally):
return Conditionally(auto_inferring_e(s.pred, inferring_map), auto_inferring_s(s.conseq, inferring_map), auto_inferring_s(s.alt, inferring_map), s.info)
elif isinstance(s, DefRegister):
clock = auto_inferring_e(s.clock, inferring_map)
reset = auto_inferring_e(s.reset, inferring_map)
init = auto_inferring_e(s.init, inferring_map)
typ = auto_inferring_t(s.typ)
inferring_map[s.name] = typ
return DefRegister(s.name, typ, clock, reset, init, s.info)
elif isinstance(s, DefWire):
inferring_map[s.name] = auto_inferring_t(s.typ)
return s
elif isinstance(s, DefMemory):
inferring_map[s.name] = auto_inferring_t(s.memType)
return s
elif isinstance(s, DefNode):
value = auto_inferring_e(s.value, inferring_map)
inferring_map[s.name] = value.typ
return DefNode(s.name, value, s.info)
elif isinstance(s, DefMemPort):
clk = auto_inferring_e(s.clk, inferring_map)
index = auto_inferring_e(s.index, inferring_map)
inferring_map[s.name] = UnknownType()
return DefMemPort(s.name, s.mem, index, clk, s.rw, s.info)
elif isinstance(s, DefInstance):
inferring_map[s.name] = UnknownType()
return s
elif isinstance(s, Connect):
return Connect(auto_inferring_e(s.loc, inferring_map), auto_inferring_e(s.expr, inferring_map), s.info, s.blocking, s.bidirection, s.mem)
else:
return s
def auto_inferring_m(m: DefModule, inferring_map: Dict[str, Type]) -> DefModule:
if isinstance(m, Module):
ports: List[Port] = []
for p in m.ports:
inferring_map[p.name] = auto_inferring_t(p.typ)
ports.append(Port(p.name, p.direction, inferring_map[p.name], p.info))
body = auto_inferring_s(m.body, inferring_map)
return Module(m.name, ports, body, m.typ, m.info)
else:
return m
for m in c.modules:
inferring_map: Dict[str, Type] = {}
modules.append(auto_inferring_m(m, inferring_map))
return Circuit(modules, c.main, c.info)

106
pyhcl/passes/check_flows.py Normal file
View File

@ -0,0 +1,106 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass, PassException, Error
from pyhcl.passes.utils import times_f_f, times_g_flip, to_flow
from pyhcl.passes.wir import *
class WrongFlow(PassException):
def __init__(self, info: Info, mname: str, expr: str, wrong: Flow, right: Flow):
super().__init__(f'{info}: [module {mname}] Expression {expr} is used as a {wrong} but can only be used as a {right}.')
class CheckFlow(Pass):
def run(self, c: Circuit):
errors = Error()
def get_flow(e: Expression, flows: Dict[str, Flow]) -> Flow:
if isinstance(e, Reference):
return flows[e.name]
elif isinstance(e, SubIndex):
return get_flow(e.expr, flows)
elif isinstance(e, SubAccess):
return get_flow(e.expr, flows)
elif isinstance(e, SubField):
if isinstance(e.expr.typ, BundleType):
for f in e.expr.typ.fields:
if f.name == e.name:
return times_g_flip(get_flow(e.expr, flows), f.flip)
return SourceFlow()
def flip_q(t: Type) -> bool:
def flip_rec(t: Type, f: Orientation) -> bool:
if isinstance(t, BundleType):
final = True
for field in t.fields:
final = flip_rec(field.typ, times_f_f(f, field.flip)) and final
return final
elif isinstance(t, VectorType):
return flip_rec(t.typ, f)
else:
return isinstance(f, Flip)
return flip_rec(t, Default())
def check_flow(info: Info, mname: str, flows: Dict[str, Flow], desired: Flow, e: Expression):
flow = get_flow(e, flows)
if isinstance(flow, SourceFlow) and isinstance(desired, SinkFlow):
errors.append(WrongFlow(info, mname, e.serialize(), desired, flow))
def check_flow_e(info: Info, mname: str, flows: Dict[str, Flow], e: Expression):
if isinstance(e, Mux):
for _, ee in e.__dict__.items():
if isinstance(ee, Expression):
check_flow(info, mname, flows, SourceFlow(), ee)
if isinstance(e, DoPrim):
for ee in e.args:
if isinstance(ee, Expression):
check_flow(info, mname, flows, SourceFlow(), ee)
for _, ee in e.__dict__.items():
if isinstance(ee, Expression):
check_flow_e(info, mname, flows, ee)
def check_flow_s(minfo: Info, mname: str, flows: Dict[str, Flow], s: Statement):
info = lambda s: minfo if isinstance(s, NoInfo) else s.info
if isinstance(s, DefWire):
flows[s.name] = DuplexFlow()
elif isinstance(s, DefRegister):
flows[s.name] = DuplexFlow()
elif isinstance(s, DefMemory):
flows[s.name] = SourceFlow()
elif isinstance(s, DefInstance):
flows[s.name] = SourceFlow()
elif isinstance(s, DefNode):
check_flow(info, mname, flows, SourceFlow(), s.value)
flows[s.name] = SourceFlow()
elif isinstance(s, DefMemPort):
flows[s.name] = SinkFlow()
elif isinstance(s, Connect):
check_flow(info, mname, flows, SinkFlow(), s.loc)
check_flow(info, mname, flows, SourceFlow(), s.expr)
...
elif isinstance(s, Conditionally):
check_flow(info, mname, flows, SourceFlow(), s.pred)
else:
...
for _, ss in s.__dict__.items():
if isinstance(ss, Expression):
check_flow_e(info, mname, flows, ss)
if isinstance(ss, Statement):
check_flow_s(minfo, mname, flows, ss)
for m in c.modules:
flows: Dict[str, Flow] = {}
if hasattr(m, 'ports') and isinstance(m.ports, list):
for p in m.ports:
flows[p.name] = to_flow(p.direction)
if hasattr(m, 'body') and isinstance(m.body, Block):
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
for stmt in m.body.stmts:
check_flow_s(m.info, m.name, flows, stmt)
errors.trigger()
return c

391
pyhcl/passes/check_form.py Normal file
View File

@ -0,0 +1,391 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass, PassException, Error
from pyhcl.passes.utils import ModuleGraph, to_flow, flow, create_exps, get_info, has_flip
from pyhcl.passes.wir import *
# ScopeView
class ScopeView:
def __init__(self, moduleNS: set, scopes: List[set]):
self.moduleNS = moduleNS
self.scopes = scopes
def declare(self, name: str):
self.moduleNS.add(name)
self.scopes[0].add(name)
# ensures that the name cannot be used again, but prevent references to this name
def add_to_namespace(self, name: str):
self.moduleNS.add(name)
def expand_m_port_visibility(self, port: DefMemPort):
mem_in_scopes = False
def expand_m_port(scope: set, mp: DefMemPort):
if mp.mem.name in scope:
scope.add(mp.name)
return scope
self.scopes = list(map(lambda scope: expand_m_port(scope, port), self.scopes))
for sx in self.scopes:
if port.mem.name in sx:
mem_in_scopes = True
if mem_in_scopes is False:
self.scopes[0].add(port.name)
def legal_decl(self, name: str) -> bool:
return name in self.moduleNS
def legal_ref(self, name: str) -> bool:
for s in self.scopes:
if name in s:
return True
return False
def get_ref(self):
for s in self.scopes:
print(s)
def child_scope(self):
return ScopeView(self.moduleNS, [set()])
def scope_view():
return ScopeView(set(), [set()])
# Custom Exceptions
class NotUniqueException(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Reference {name} does not have a unique name.')
class InvalidLOCException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Invalid connect to an expression that is not a reference or a WritePort.')
class NegUIntException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] UIntLiteral cannot be negative.')
class UndecleardReferenceException(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Reference {name} is not declared.')
class PoisonWithFlipException(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Poison {name} cannot be a bundle type with flips.')
class MemWithFlipException(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Memory {name} cannot be a bundle type with flips.')
class IllegalMemLatencyException(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Memory {name} must have non-negative read latency and positive write latency.')
class RegWithFlipException(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Register {name} cannot be a bundle type with flips.')
class InvalidAccessException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Invalid access to non-reference.')
class ModuleNameNotUniqueException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: Repeat definition of module {mname}')
class DefnameConflictException(PassException):
def __init__(self, info: Info, mname: str, defname: str):
super().__init__(f'{info}: defname {defname} of extmodule {mname} conflicts with an existing module.')
class DefnameDifferentPortsException(PassException):
def __init__(self, info: Info, mname: str, defname: str):
super().__init__(f'{info}: ports of extmodule {mname} with defname {defname} are different for an extmodule with the same defname.')
class DefnameDifferentPortsException(PassException):
def __init__(self, info: Info, name: str):
super().__init__(f'{info}: Module {name} is not defined.')
class IncorrectNumArgsException(PassException):
def __init__(self, info: Info, mname: str, op: str, n: int):
super().__init__(f'{info}: [module {mname}] Primop {op} requires {n} expression arguments.')
class IncorrectNumConstsException(PassException):
def __init__(self, info: Info, mname: str, op: str, n: int):
super().__init__(f'{info}: [module {mname}] Primop {op} requires {n} integer arguments.')
class NegWidthException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Width cannot be negative.')
class NegVecSizeException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Vector type size cannot be negative.')
class NegMemSizeException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Memory size cannot be negative or zero.')
class InstanceLoop(PassException):
def __init__(self, info: Info, mname: str, loop: str):
super().__init__(f'{info}: [module {mname}] Has instance loop {loop}.')
class NoTopModuleException(PassException):
def __init__(self, info: Info, name: str):
super().__init__(f'{info}: A single module must be named {name}.')
class NegArgException(PassException):
def __init__(self, info: Info, mname: str, op: str, value: int):
super().__init__(f'{info}: [module {mname}] Primop {op} argument {value} < 0.')
class LsbLargerThanMsbException(PassException):
def __init__(self, info: Info, mname: str, op: str, lsb: int, msb: int):
super().__init__(f'{info}: [module {mname}] Primop {op} lsb {lsb} > {msb}.')
class ResetInputException(PassException):
def __init__(self, info: Info, mname: str, expr: Expression):
super().__init__(f'{info}: [module {mname}] Abstract Reset not allowed as top-level input: {expr.serialize()}')
class ResetExtModuleOutputException(PassException):
def __init__(self, info: Info, mname: str, expr: Expression):
super().__init__(f'{info}: [module {mname}] Abstract Reset not allowed as ExtModule output: {expr.serialize()}')
class ModuleNotDefinedException(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: Module {name} is not defined.')
class CircuitHasNoModules(PassException):
def __init__(self, info: Info, cname: str):
super().__init__(f'{info}: Circuit {cname} has no modules.')
class ModuleHasNoPorts(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: Module {mname} has no ports.')
class ModuleHasNoBody(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: Module {mname} has no body.')
class CheckHighForm(Pass):
def __init__(self, c: Circuit):
self.c: Circuit = c
self.ms: List[DefModule] = c.modules
self.module_names: List[str] = [_.name for _ in c.modules]
self.int_module_name: List[str] = [_.name for _ in c.modules if isinstance(_, Module)]
self.errors: Error = Error()
def check_unique_module_name(self):
for idx in range(len(self.module_names)):
if self.int_module_name[idx] in self.int_module_name[idx:]:
m = self.ms[idx]
self.errors.append(ModuleNameNotUniqueException(m.info, m.name))
def check_extmodule(self):
for m in self.ms:
if isinstance(m, ExtModule) and m.name in self.int_module_name:
self.errors.append(DefnameConflictException(m.info, m.name, m.defname))
def strip_width(self, typ: Type) -> Type:
if isinstance(typ, GroundType):
return typ.map_width(UnknownWidth)
elif isinstance(typ, AggregateType):
return typ.map_type(self.strip_width())
def check_highForm_primOp(self, info: Info, mname: str, e: DoPrim):
def correct_num(ne, nc):
if isinstance(ne, int) and len(e.args) != ne:
self.errors.append(IncorrectNumArgsException(info, mname, e.op.serialize(), ne))
if len(e.consts) != nc:
self.errors.append(IncorrectNumConstsException(info, mname, e.op.serialize(), nc))
def non_negative_consts():
for _ in [c for c in e.consts if c < 0]:
self.errors.append(NegArgException(info, mname, e.op.serialize(), _))
if isinstance(e.op, (Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Dshl, Dshr, And, Or, Xor, Cat)):
correct_num(2, 0)
elif isinstance(e.op, (AsUInt, AsSInt, AsClock, Cvt, Neq, Not)):
correct_num(1, 0)
elif isinstance(e.op, AsFixedPoint):
correct_num(1, 1)
elif isinstance(e.op, (Shl, Shr, Pad, Head, Tail)):
correct_num(1, 1)
non_negative_consts()
elif isinstance(e.op, Bits):
correct_num(1, 2)
non_negative_consts()
if len(e.consts) == 2:
msb, lsb = e.consts[0], e.consts[1]
if msb < lsb:
self.errors.append(LsbLargerThanMsbException(info, mname, e.op.serialize(), lsb, msb))
elif isinstance(e.op, (Andr, Orr, Xorr, Neg)):
correct_num(1, 0)
def check_valid_loc(self, info: Info, mname: str, e: Expression):
if isinstance(e, (UIntLiteral, SIntLiteral, DoPrim)):
self.errors.append(InvalidLOCException(info, mname))
def check_instance(self, info: Info, child: str, parent: str):
if child not in self.module_names:
self.errors.append(ModuleNotDefinedException(info, parent, child))
childToParent = ModuleGraph().add(parent, child)
if childToParent is not None and len(childToParent) > 0:
self.errors.append(InstanceLoop(info, parent, "->".join(childToParent)))
def check_high_form_w(self, info: Info, mname: str, w: Width):
if isinstance(w, IntWidth) and w.width < 0:
self.errors.append(NegWidthException(info, mname))
def check_high_form_t(self, info: Info, mname: str, typ: Type):
t_attr = typ.__dict__.items()
for _, ta in t_attr:
if isinstance(ta, Type):
self.check_high_form_t(info, mname, ta)
if isinstance(ta, Width):
self.check_high_form_w(info, mname, ta)
if isinstance(typ, VectorType) and typ.size < 0:
self.errors.append(NegVecSizeException(info, mname))
def valid_sub_exp(self, info: Info, mname: str, e: Expression):
if isinstance(e, (Reference, SubField, SubIndex, SubAccess)):
...
elif isinstance(e, (Mux, ValidIf)):
...
else:
self.errors.append(InvalidAccessException(info, mname))
def check_high_form_e(self, info: Info, mname: str, names: ScopeView, e: Expression):
e_attr = e.__dict__.items()
if isinstance(e, Reference) and names.legal_ref(e.name) is False:
self.errors.append(UndecleardReferenceException(info, mname, e.name))
...
elif isinstance(e, UIntLiteral) and e.value < 0:
self.errors.append(NegUIntException(info, mname, e.name))
elif isinstance(e, DoPrim):
self.check_highForm_primOp(info, mname, e)
elif isinstance(e, (Reference, UIntLiteral, Mux, ValidIf)):
...
elif isinstance(e, SubAccess):
self.valid_sub_exp(info, mname, e.expr)
else:
for _, ea in e_attr:
if isinstance(ea, Expression):
self.valid_sub_exp(info, mname, ea)
for _, ea in e_attr:
if isinstance(ea, Width):
self.check_high_form_w(info, mname + '/' + e.serialize(), ea)
if isinstance(ea, Expression):
self.check_high_form_e(info, mname, names, ea)
def check_name(self, info: Info, mname: str, names: ScopeView, referenced: bool, s: Statement):
if referenced is False:
return
if len(s.name) == 0:
assert referenced is False, 'A statement with an empty name cannot be used as a reference!'
else:
if names.legal_decl(s.name) is True:
self.errors.append(NotUniqueException(info, mname, s.name))
if referenced:
names.declare(s.name)
else:
names.add_to_namespace(s.name)
def check_high_form_s(self, minfo: Info, mname: str, names: ScopeView, s: Statement):
s_attr = s.__dict__.items()
t_info = get_info(s)
info = t_info if isinstance(t_info, NoInfo) is False else minfo
referenced = True if isinstance(s, (DefWire, DefRegister, DefInstance, DefMemory, DefNode, Port)) else False
self.check_name(info, mname, names, referenced, s)
if isinstance(s, DefRegister):
if has_flip(s.typ):
self.errors.append(RegWithFlipException(info, mname, s.name))
elif isinstance(s, DefMemory):
if has_flip(s.memType.typ):
self.errors.append(MemWithFlipException(info, mname, s.name))
if s.memType.size < 0:
self.errors.append(NegMemSizeException(info, mname))
elif isinstance(s, DefInstance):
self.check_instance(info, mname, s.module)
elif isinstance(s, Connect):
self.check_valid_loc(info, mname, s.loc)
elif isinstance(s, DefMemPort):
names.expand_m_port_visibility(s)
elif isinstance(s, Block):
for stmt in s.stmts:
self.check_high_form_s(info, mname, names, stmt)
else:
...
for _, sa in s_attr:
if isinstance(sa, Type):
self.check_high_form_t(info, mname, sa)
elif isinstance(sa, Expression):
self.check_high_form_e(info, mname, names, sa)
if isinstance(s, Conditionally):
self.check_high_form_s(minfo, mname, names, s.conseq)
self.check_high_form_s(minfo, mname, names, s.alt)
else:
for _, sa in s_attr:
if isinstance(sa, Statement):
self.check_high_form_s(minfo, mname, names, sa)
def check_high_form_p(self, mname: str, names: ScopeView, p: Port):
if names.legal_decl(p.name) is True:
self.errors.append(NotUniqueException(NoInfo, mname, p.name))
names.declare(p.name)
self.check_high_form_t(p.info, mname, p.typ)
def find_bad_reset_type_ports(self, m: DefModule, dir: Direction):
bad_reset_type_ports = []
bad = to_flow(dir)
gen = ((create_exps(ref), p1) for (ref, p1) in [(Reference(p.name, p.typ), p) for p in m.ports])
for expr, port in gen:
if isinstance(expr, list):
for exx in expr:
if exx is not None and exx.typ == ResetType and flow(exx) == bad:
bad_reset_type_ports.append((port, exx))
else:
if expr is not None and expr.typ == ResetType and flow(expr) == bad:
bad_reset_type_ports.append((port, expr))
return bad_reset_type_ports
def check_high_form_m(self, m: DefModule):
names = scope_view()
if hasattr(m, 'ports') and isinstance(m.ports, list):
for p in m.ports:
self.check_high_form_p(m.name, names, p)
else:
self.errors.append(ModuleHasNoPorts(m.info, m.name))
if hasattr(m, 'body') and isinstance(m.body, Block):
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
for s in m.body.stmts:
self.check_high_form_s(m.info, m.name, names, s)
else:
self.errors.append(ModuleHasNoBody(m.info, m.name))
if isinstance(m, ExtModule):
for port, expr in self.find_bad_reset_type_ports(m, Output):
self.errors.append(ResetExtModuleOutputException(port.info, m.name, expr))
def run(self):
if hasattr(self.c, 'modules') and isinstance(self.c.modules, list):
for m in self.c.modules:
self.check_high_form_m(m)
else:
self.errors.append(CircuitHasNoModules(self.c.info, self.c.main))
if self.c.main not in self.int_module_name:
self.errors.append(NoTopModuleException(self.c.info, self.c.main))
self.errors.trigger()
return self.c

288
pyhcl/passes/check_types.py Normal file
View File

@ -0,0 +1,288 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass, PassException, Error
from pyhcl.passes.utils import times_f_f
from pyhcl.passes.wir import WrappedType
# Custom Exceptions
class SubfieldNotInBundle(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Subfield {name} is not in bundle.')
class SubfieldOnNonBundle(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Subfield {name} is accessed on non-bundle.')
class IndexTooLarge(PassException):
def __init__(self, info: Info, mname: str, value: int):
super().__init__(f'{info}: [module {mname}] Index with value {value} is too large.')
class IndexOnNonVector(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Index illegal on non-vector type.')
class AccessIndexNotUInt(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Access index must be a UInt type')
class IndexNotUInt(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Index is not of UIntType.')
class EnableNotUInt(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Enable is not of UIntType.')
class InvalidConnect(PassException):
def __init__(self, info: Info, mname: str, cond: str, lhs: Expression, rhs: Expression):
ltyp = f'\t{lhs.serialize()}: {lhs.typ.serialize()}'
rtyp = f'\t{rhs.serialize()}: {rhs.typ.serialize()}'
super().__init__(f'{info}: [module {mname}] Type mismatch in \'{cond}\'.\n{ltyp}\n{rtyp}')
class ReqClk(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Requires a clock typed signal.')
class RegReqClk(PassException):
def __init__(self, info: Info, mname: str, name: str):
super().__init__(f'{info}: [module {mname}] Register {name} requires a clock typed signal.')
class EnNotUInt(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Enable must be a 1-bit UIntType typed signal.')
class PredNotUInt(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Predicate not a 1-bit UIntType.')
class OpNotGround(PassException):
def __init__(self, info: Info, mname: str, op: str):
super().__init__(f'{info}: [module {mname}] Primop {op} cannot operate on non-ground types.')
class OpNotUInt(PassException):
def __init__(self, info: Info, mname: str, op: str, e: str):
super().__init__(f'{info}: [module {mname}] Primop {op} requires argument {e} to be a UInt type.')
class InvalidRegInit(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Type of init must match type of DefRegister.')
class OpNotAllUInt(PassException):
def __init__(self, info: Info, mname: str, op: str):
super().__init__(f'{info}: [module {mname}] Primop {op} requires all arguments to be UInt type.')
class OpNotAllSameType(PassException):
def __init__(self, info: Info, mname: str, op: str):
super().__init__(f'{info}: [module {mname}] Primop {op} requires all operands to have the same type.')
class OpNotCorrectType(PassException):
def __init__(self, info: Info, mname: str, op: str, typs: List[str]):
super().__init__(f'{info}: [module {mname}] Primop {op} does not have correct arg types: {typs}.')
class NodePassiveType(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Node must be a passive type.')
class MuxSameType(PassException):
def __init__(self, info: Info, mname: str, t1: str, t2: str):
super().__init__(f'{info}: [module {mname}] Must mux between equivalent types: {t1} != {t2}.')
class MuxPassiveTypes(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Must mux between passive types.')
class MuxCondUInt(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] A mux condition must be of type 1-bit UInt.')
class MuxClock(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Firrtl does not support muxing clocks.')
class ValidIfPassiveTypes(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] Must validif a passive type.')
class ValidIfCondUInt(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [module {mname}] A validif condition must be of type UInt.')
class IllegalResetType(PassException):
def __init__(self, info: Info, mname: str, exp: str):
super().__init__(f'{info}: [module {mname}] Register resets must have type Reset, AsyncReset, or UInt<1>: {exp}.')
class IllegalUnknownType(PassException):
def __init__(self, info: Info, mname: str, exp: str):
super().__init__(f'{info}: [module {mname}] Uninferred type: {exp}.')
# TODO PrintfArgNotGround | OpNoMixFix | OpNotAnalog | IllegalAnalogDeclaration | IllegalAttachExp
class CheckTypes(Pass):
def legal_reset_type(self, typ: Type) -> bool:
if isinstance(typ, UIntType) and isinstance(typ.width, IntWidth):
return typ.width.width == 1
elif isinstance(typ, AsyncResetType):
return True
elif isinstance(typ, ResetType):
return True
else:
return False
def legal_cond_type(self, typ: Type) -> bool:
if isinstance(typ, UIntType) and isinstance(typ.width, IntWidth):
return typ.width.width == 1
elif isinstance(typ, UIntType):
return True
else:
return False
def bulk_equals(self, t1: Type, t2: Type, flip1: Orientation, flip2: Orientation) -> bool:
if isinstance(t1, ClockType) and isinstance(t2, ClockType):
return flip1 == flip2
elif isinstance(t1, UIntType) and isinstance(t2, UIntType):
return flip1 == flip2
elif isinstance(t1, SIntType) and isinstance(t2, SIntType):
return flip1 == flip2
elif isinstance(t1, AsyncResetType) and isinstance(t2, AsyncResetType):
return flip1 == flip2
elif isinstance(t1, ResetType):
return self.legal_reset_type(t2) and flip1 == flip2
elif isinstance(t2, ResetType):
return self.legal_reset_type(t1) and flip1 == flip2
elif isinstance(t1, BundleType) and isinstance(t2, BundleType):
t1_fields = {}
for f1 in t1.fields:
t1_fields[f1.name] = (f1.typ, f1.flip)
for f2 in f2.fields:
if f2.name in t1_fields.keys():
t1_flip = t1_fields[f2.name].typ
return self.bulk_equals(t1_flip, f2.typ, times_f_f(flip1, t1_flip) ,times_f_f(flip2, f2.flip))
else:
return True
elif isinstance(t1, VectorType) and isinstance(t2, VectorType):
return self.bulk_equals(t1.typ, t2.typ, flip1, flip2)
else:
return False
@staticmethod
def valid_connect(locTyp: Type, exprTyp: Type) -> bool:
if isinstance(locTyp, (ClockType, UIntType)) and isinstance(exprTyp, (ClockType, UIntType)):
return True
return type(locTyp) == type(exprTyp)
def valid_connects(self, c: Connect) -> bool:
return CheckTypes.valid_connect(c.loc.typ, c.expr.typ)
def run(self, c: Circuit):
errors = Error()
def passive(t: Type) -> bool:
if isinstance(t, (UIntType, SIntType)):
return True
elif isinstance(t, VectorType):
return passive(t.typ)
elif isinstance(t, BundleType):
final = True
for f in t.fields:
final = f.flip == Default and passive(f.typ) and final
return final
else:
return True
def check_typs_primop(info: Info, mname: str, e: DoPrim):
def check_all_typs(exprs: List[Expression], okUInt: bool, okSInt: bool, okClock: bool, okAsync: bool):
for expr in exprs:
if isinstance(expr.typ, UIntType) and okUInt is False:
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
elif isinstance(expr.typ, UIntType) and okSInt is False:
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
elif isinstance(expr.typ, ClockType) and okClock is False:
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
elif isinstance(expr.typ, AsyncResetType) and okAsync is False:
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
if isinstance(e.op, (AsUInt, AsSInt, AsClock, AsyncResetType)):
...
elif isinstance(e.op, (Dshl, Dshr)):
check_all_typs(list(e.args[0]), True, True, False, False)
check_all_typs(e.args[1:], True, False, False, False)
elif isinstance(e.op, (Add, Sub, Mul, Lt, Leq, Gt, Geq, Eq, Neq)):
check_all_typs(e.args, True, True, False, False)
elif isinstance(e.op, (Pad, Bits, Head, Tail)):
check_all_typs(e.args, True, True, False, False)
elif isinstance(e.op, (Shr, Shl, Cat)):
check_all_typs(e.args, True, True, False, False)
else:
check_all_typs(e.args, True, True, False, False)
def check_types_e(info: Info, mname: str, e: Expression):
if isinstance(e, DoPrim):
check_typs_primop(info, mname, e)
elif isinstance(e, Mux):
if WrappedType(e.tval.typ) != WrappedType(e.fval.typ):
errors.append(MuxSameType(info, mname, e.tval.typ.serialize(), e.fval.typ.serialize()))
if passive(e.typ) is False:
errors.append(MuxPassiveTypes(info, mname))
if self.legal_cond_type(e.cond.typ) is False:
errors.append(MuxCondUInt(info, mname))
elif isinstance(e, ValidIf):
if passive(e.typ) is False:
errors.append(ValidIfPassiveTypes(info, mname))
if isinstance(e.cond.typ, UIntType):
...
else:
errors.append(ValidIfCondUInt(info, mname))
else:
...
for _, ee in e.__dict__.items():
if isinstance(ee, Expression):
check_types_e(info, mname, ee)
def check_types_s(minfo: Info, mname: str, s: Statement):
def get_info(s):
if isinstance(s, NoInfo):
return minfo
else:
return s
if isinstance(s, Connect) and self.valid_connects(s) is False:
con_msg = Connect(s.loc, s.expr, NoInfo()).serialize()
errors.append(InvalidConnect(get_info(s), mname, con_msg, s.loc, s.expr))
elif isinstance(s, DefRegister):
if isinstance(s.init, Expression) and WrappedType(s.typ) != WrappedType(s.init.typ):
errors.append(InvalidRegInit(get_info(s), mname))
if isinstance(s.init, Expression) and CheckTypes.valid_connect(s.typ, s.init.typ) is False:
con_msg = Connect(s.loc, s.expr, NoInfo()).serialize()
errors.append(InvalidConnect(get_info(s), mname, con_msg, Reference(s.name, s.typ), s.init))
if isinstance(s.init, Expression) and self.legal_reset_type(s.reset.typ) is False:
errors.append(IllegalResetType(get_info(s), mname, s.name))
if not isinstance(s.clock.typ, ClockType) or not isinstance(s.clock.typ.width, IntWidth) or s.clock.typ.width.width != 1:
errors.append(RegReqClk(get_info(s), mname, s.name))
elif isinstance(s, Conditionally) and self.legal_cond_type(s.pred.typ) is False:
errors.append(PredNotUInt(get_info(s), mname))
elif isinstance(s, DefNode):
if passive(s.value.typ) is False:
errors.append(NodePassiveType(get_info(s), mname))
elif isinstance(s, DefMemory):
...
else:
...
for _, ss in s.__dict__.items():
if isinstance(ss, Statement):
check_types_s(get_info(s), mname, ss)
if isinstance(ss, Expression):
check_types_e(get_info(s), mname, ss)
for m in c.modules:
if hasattr(m, 'body') and isinstance(m.body, Block):
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
for s in m.body.stmts:
check_types_s(m.info, m.name, s)
errors.trigger()
return c

View File

@ -0,0 +1,181 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass, PassException, Error
from pyhcl.passes.utils import get_width, get_info
from pyhcl.passes.check_types import IllegalResetType, CheckTypes, InvalidConnect
# MaxWidth
maxWidth = 1000000
class UninferredWidth(PassException):
def __init__(self, info: Info, target: str):
super().__init__(f'{info}: Uninferred width for target below. (Did you forget to assign to it?) \n{target}')
class InvalidRange(PassException):
def __init__(self, info: Info, target: str, i: Type):
super().__init__(f'{info}: Invalid range {i.serialize()} for target below. (Are the bounds valid?) \n{target}')
class WidthTooSmall(PassException):
def __init__(self, info: Info, mname: str, b: int):
super().__init__(f'{info} : [target {mname}] Width too small for constant {b}.')
class WidthTooBig(PassException):
def __init__(self, info: Info, mname: str, b: int):
super().__init__(f'{info} : [target ${mname}] Width {b} greater than max allowed width of {maxWidth} bits')
class DshlTooBig(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info} : [target {mname}] Width of dshl shift amount must be less than {maxWidth} bits.')
class MultiBitAsClock(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info} : [target {mname}] Cannot cast a multi-bit signal to a Clock.')
class MultiBitAsAsyncReset(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info} : [target {mname}] Cannot cast a multi-bit signal to an AsyncReset.')
class NegWidthException(PassException):
def __init__(self, info: Info, mname: str):
super().__init__(f'{info}: [target {mname}] Width cannot be negative or zero.')
class BitsWidthException(PassException):
def __init__(self, info: Info, mname: str, hi: int, width: int, exp: str):
super().__init__(f'{info}: [target {mname}] High bit {hi} in bits operator is larger than input width {width} in {exp}.')
class HeadWidthException(PassException):
def __init__(self, info: Info, mname: str, n: int, width: int):
super().__init__(f'{info}: [target {mname}] Parameter {n} in head operator is larger than input width {width}.')
class TailWidthException(PassException):
def __init__(self, info: Info, mname: str, n: int, width: int):
super().__init__(f'{info}: [target {mname}] Parameter {n} in tail operator is larger than input width {width}.')
class CheckWidths(Pass):
def run(self, c: Circuit):
errors = Error()
def gen_target(name: str, subname: str) -> str:
return f'{name}-{subname}'
def check_width_w(info: Info, target: str, t: Type, w: Width):
if isinstance(w, IntWidth) and w.width >= maxWidth:
errors.append(WidthTooBig(info, target, w.width))
elif isinstance(w, IntWidth):
...
else:
errors.append(UninferredWidth(info, target))
def has_width(typ: Type) -> bool:
if isinstance(typ, GroundType) and hasattr(typ, 'width') and isinstance(typ.width, IntWidth):
return True
elif isinstance(typ, GroundType):
return False
else:
raise PassException(f'hasWidth - {typ}')
def check_width_t(info: Info, target: str, t: Type):
if isinstance(t, BundleType):
for f in t.fields:
check_width_f(info, target, f)
else:
for _, tt in t.__dict__.items():
if isinstance(tt, Type):
check_width_t(info, target, tt)
for _, tt in t.__dict__.items():
if isinstance(tt, Width):
check_width_w(info, target, t, tt)
def check_width_f(info: Info, target: str, f: Field):
check_width_t(info, target, f.typ)
def check_width_e_leaf(info: Info, target: str, expr: Expression):
if isinstance(expr, UIntLiteral) and get_binary_width(expr.value) > get_width(expr.width):
errors.append(WidthTooSmall(info, target, expr.value))
elif isinstance(expr, SIntLiteral) and get_binary_width(expr.value) + 1 > get_width(expr.width):
errors.append(WidthTooSmall(info, target, expr.value))
elif isinstance(expr, DoPrim) and len(expr.args) == 2:
if isinstance(expr.op, Dshl) and has_width(expr.args[0].typ) and get_width(expr.args[1].typ.width) > maxWidth:
errors.append(DshlTooBig(info, target))
elif isinstance(expr, DoPrim) and len(expr.args) == 1:
if isinstance(expr.op, Bits) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.consts[0]:
errors.append(BitsWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width), expr.serialize()))
elif isinstance(expr.op, Head) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.args[0]:
errors.append(HeadWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width)))
elif isinstance(expr.op, Tail) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.args[0]:
errors.append(TailWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width)))
elif isinstance(expr.op, AsClock) and expr.consts[0] != 1:
errors.append(MultiBitAsClock(info, target))
def check_width_e(info: Info, target: str, rec_depth: int, e: Expression):
check_width_e_leaf(info, target, e)
if isinstance(e, (Mux, ValidIf, DoPrim)):
if rec_depth > 0:
for _, ee in e.__dict__.items():
if isinstance(ee, Expression):
check_width_e(info, target, rec_depth - 1, ee)
else:
check_width_e_dfs(info, target, e)
def check_width_e_dfs(info: Info, target: str, expr: Expression):
stack = expr.__dict__.items()
def push(e: Expression):
stack.append(e)
while len(stack) > 0:
current = stack
check_width_e_leaf(info, target, current)
for _, leaf in current.__dict__.items():
if isinstance(leaf, Expression):
push(leaf)
def check_width_s(minfo: Info, target: str, s: Statement):
info = get_info(s)
if isinstance(info, NoInfo):
info = minfo
for _, ss in s.__dict__.items():
if isinstance(ss, Expression):
check_width_e(info, target, 4, ss)
if isinstance(ss, Statement):
check_width_s(info, target, ss)
if isinstance(ss, Type):
check_width_t(info, target, ss)
if isinstance(s, DefRegister):
sx = s.reset.typ if isinstance(s.reset, Expression) else None
if sx is None:
...
elif isinstance(sx, UIntType) and get_width(sx.width) == 1:
...
elif isinstance(sx, AsyncResetType):
...
elif isinstance(sx, ResetType):
...
else:
errors.append(IllegalResetType(info, target, s.name))
if isinstance(s.init, Expression) and CheckTypes.valid_connect(s.typ, s.init.typ) is False:
con_msg = DefRegister(s.name, s.typ, s.clock, s.reset, s.init, NoInfo())
errors.append(InvalidConnect(info, target, con_msg, _, s.init))
def check_width_p(minfo: Info, target: str, p: Port):
check_width_t(p.info, target, p.typ)
def check_width_m(target: str, m: DefModule):
if hasattr(m, 'ports') and isinstance(m.ports, list):
for p in m.ports:
check_width_p(m.info, gen_target(target, m.name), p)
if hasattr(m, 'body') and isinstance(m.body, Block):
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
for s in m.body.stmts:
check_width_s(m.info, gen_target(target, m.name), s)
for m in c.modules:
check_width_m(c.main, m)
errors.trigger()
return c

View File

@ -0,0 +1,150 @@
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
@dataclass
class ExpandAggregate(Pass):
def run(self, c: Circuit) -> Circuit:
modules: List[DefModule] = []
def flip_direction(d: Direction) -> Direction:
if isinstance(d, Output):
return Input()
else:
return Output()
def flatten_vector(name: str, t: Type) -> list:
decs = []
if isinstance(t, VectorType):
for nx, tx in [(f"{name}_{i}", t.typ) for i in range(t.size)]:
if isinstance(tx, VectorType):
decs = decs + flatten_vector(nx, tx)
elif isinstance(tx, BundleType):
decs = decs + [(nxx, txx) for nxx, _, txx in flatten_bundle(nx, tx)]
else:
decs.append((nx, tx))
return decs
def flatten_bundle(name: str, t: Type) -> list:
decs = []
if isinstance(t, BundleType):
for nx, fx, tx in [(f"{name}_{f.name}", f.flip, f.typ) for f in t.fields]:
if isinstance(tx, BundleType):
decs = decs + flatten_bundle(nx, tx)
elif isinstance(tx, VectorType):
decs = decs + [(nxx, fx, txx) for nxx, txx in flatten_vector(nx, tx)]
else:
decs.append((nx, fx, tx))
return decs
def expand_aggregate_wire(stmt: Statement, stmts: List[Statement]):
if isinstance(stmt.typ, VectorType):
typs = flatten_vector(stmt.name, stmt.typ)
for nx, tx in typs:
stmts.append(DefWire(nx, tx, stmt.info))
elif isinstance(stmt.typ, BundleType):
typs = flatten_bundle(stmt.name, stmt.typ)
for nx, _, tx in typs:
stmts.append(DefWire(nx, tx, stmt.info))
else:
stmts.append(stmt)
def expand_aggregate_node(stmt: Statement, stmts: List[Statement]):
value = stmt.value
if isinstance(value.typ, VectorType):
if isinstance(value, Mux):
tval, fval = value.tval, value.fval
tval_typs, fval_typs = flatten_vector(tval.name, tval.typ), flatten_vector(fval.name, fval.typ)
typs = flatten_vector(stmt.name, value.typ)
for i in range(len(typs)):
stmts.append(DefNode(typs[i][0], Mux(value.cond,
Reference(tval_typs[i][0], tval_typs[i][1]), Reference(fval_typs[i][0], fval_typs[i][1]), typs[i][1])))
if isinstance(value, ValidIf):
val = value.value
val_typs = flatten_vector(val.name, val.value)
typs = flatten_vector(stmt.name, value.typ)
for i in range(len(typs)):
stmts.append(DefNode(typs[i][0], ValidIf(value.cond,
Reference(val_typs[i][0], val_typs[i][1]), typs[i][1])))
if isinstance(value, DoPrim):
...
elif isinstance(value.typ, BundleType):
if isinstance(value, Mux):
tval, fval = value.tval, value.fval
tval_typs, fval_typs = flatten_bundle(tval.name, tval.typ), flatten_bundle(fval.name, fval.typ)
typs = flatten_bundle(stmt.name, value.typ)
for i in range(len(typs)):
stmts.append(DefNode(typs[i][0], Mux(value.cond,
Reference(tval_typs[i][0], tval_typs[i][2]), Reference(fval_typs[i][0], fval_typs[i][2]), typs[i][2])))
if isinstance(value, ValidIf):
val = value.value
val_typs = flatten_bundle(val.name, val.value)
typs = flatten_bundle(stmt.name, value.typ)
for i in range(len(typs)):
stmts.append(DefNode(typs[i][0], ValidIf(value.cond,
Reference(val_typs[i][0], val_typs[i][2]), typs[i][2])))
if isinstance(value, DoPrim):
...
else:
stmts.append(stmt)
def expand_aggregate_reg(stmt: Statement, stmts: List[Statement]):
typ = stmt.typ
if isinstance(typ, VectorType):
typs = flatten_vector(stmt.name, typ)
for nx, tx in typs:
init = Reference(nx, tx)
stmts.append(DefRegister(nx, tx, stmt.clock, stmt.reset, init, stmt.info))
elif isinstance(typ, BundleType):
typs = flatten_bundle(stmt.name, stmt.typ)
for nx, _, tx in typs:
stmts.append(DefRegister(nx, tx, stmt.clock, stmt.reset, stmt.init, stmt.info))
else:
stmts.append(stmt)
def expand_aggregate_s(stmts: List[Statement]) -> List[Statement]:
new_stmts = []
for stmt in stmts:
if isinstance(stmt, DefWire):
expand_aggregate_wire(stmt, new_stmts)
elif isinstance(stmt, DefNode):
expand_aggregate_node(stmt, new_stmts)
elif isinstance(stmt, DefRegister):
expand_aggregate_reg(stmt, new_stmts)
else:
new_stmts.append(stmt)
return new_stmts
def expand_aggregate_p(p: Port, ports: List[Port]):
if isinstance(p.typ, VectorType):
typs = flatten_vector(p.name, p.typ)
for nx, tx in typs:
ports.append(Port(nx, p.direction, tx, p.info))
elif isinstance(p.typ, BundleType):
typs = flatten_bundle(p.name, p.typ)
for nx, fx, tx in typs:
dir = p.direction if isinstance(fx, Default) else flip_direction(p.direction)
ports.append(Port(nx, dir, tx, p.info))
else:
ports.append(p)
def expand_aggregate_ps(ps: List[Port]) -> List[Port]:
new_ports = []
for p in ps:
expand_aggregate_p(p, new_ports)
return new_ports
def expand_aggregate_m(m: DefModule) -> DefModule:
if isinstance(m, ExtModule):
return m
if not hasattr(m, 'body') or not isinstance(m.body, Block):
return m
if not hasattr(m.body, 'stmts') or not isinstance(m.body.stmts, list):
return m
return Module(m.name, expand_aggregate_ps(m.ports), Block(expand_aggregate_s(m.body.stmts)), m.typ, m.info)
for m in c.modules:
modules.append(expand_aggregate_m(m))
return Circuit(modules, c.main, c.info)

View File

@ -0,0 +1,97 @@
from typing import List, Dict
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
from pyhcl.passes.utils import get_binary_width
DEFAULT_READ_LATENCY = 0
DEFAULT_WRITE_LATENCY = 1
@dataclass
class ExpandMemory(Pass):
def run(self, c: Circuit):
def get_mem_ports(stmts: List[Statement], writes: Dict[str, List[Statement]], reads: Dict[str, List[Statement]]):
for stmt in stmts:
if isinstance(stmt, DefMemPort):
if stmt.rw is True:
if stmt.mem.name in reads:
reads[stmt.mem.name] = reads[stmt.mem.name] + [stmt.name]
else:
reads[stmt.mem.name] = [stmt.name]
else:
if stmt.mem.name in writes:
writes[stmt.mem.name] = writes[stmt.mem.name] + [stmt.name]
else:
writes[stmt.mem.name] = [stmt.name]
def expand_mem_port(stmts: List[Statement], target: Statement):
addr_width = IntWidth(get_binary_width(target.mem.typ.size))
# addr
stmts.append(Connect(
SubField(SubField(Reference(target.mem.name, UIntType(addr_width)),target.name, UIntType(addr_width)), 'addr', UIntType(addr_width)),
UIntLiteral(target.index.value, addr_width)))
# en
stmts.append(Connect(
SubField(SubField(Reference(target.mem.name, UIntType(IntWidth(1))),target.name, UIntType(IntWidth(1))), 'en', UIntType(IntWidth(1))),
UIntLiteral(1, IntWidth(1))))
# clk
stmts.append(Connect(
SubField(SubField(Reference(target.mem.name, ClockType()),target.name, ClockType()), 'clk', ClockType()),
target.clk))
# mask
if target.rw is False:
stmts.append(Connect(
SubField(SubField(Reference(target.mem.name, UIntType(IntWidth(1))),target.name, UIntType(IntWidth(1))), 'mask', UIntType(IntWidth(1))),
UIntLiteral(1, IntWidth(1))))
def expand_memory_e(s: Statement, ports: Dict[str, Statement]) -> Statement:
loc, expr = s.loc, s.expr
if isinstance(loc, Reference) and loc.name in ports:
loc = SubField(SubField(Reference(ports[loc.name].mem.name, loc.typ), loc.name, loc.typ), 'data', loc.typ)
elif isinstance(expr, Reference) and expr.name in ports:
expr = SubField(SubField(Reference(ports[expr.name].mem.name, expr.typ), expr.name, expr.typ), 'data', expr.typ)
return Connect(loc, expr, s.info, s.blocking, s.bidirection, s.mem)
def expand_memory_s(stmts: List[Statement]) -> List[Statement]:
new_stmts: List[Statement] = []
writes: Dict[str, List[Statement]] = {}
reads: Dict[str, List[Statement]] = {}
ports: Dict[str, List[Statement]] = {}
get_mem_ports(stmts, writes, reads)
for stmt in stmts:
if isinstance(stmt, DefMemory):
new_stmts.append(WDefMemory(
stmt.name,
stmt.memType,
stmt.memType.typ,
stmt.memType.size,
DEFAULT_READ_LATENCY,
DEFAULT_WRITE_LATENCY,
reads[stmt.name],
writes[stmt.name]))
elif isinstance(stmt, DefMemPort):
expand_mem_port(new_stmts, stmt)
ports[stmt.name] = stmt
elif isinstance(stmt, Connect):
new_stmts.append(expand_memory_e(stmt, ports))
else:
new_stmts.append(stmt)
return new_stmts
def expand_memory_m(m: DefModule) -> DefModule:
return Module(
m.name,
m.ports,
Block(expand_memory_s(m.body.stmts)),
m.typ,
m.info
)
new_modules = []
for m in c.modules:
if isinstance(m, Module):
new_modules.append(expand_memory_m(m))
else:
new_modules.append(m)
return Circuit(new_modules, c.main, c.info)

View File

@ -0,0 +1,122 @@
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from typing import List, Dict
from dataclasses import dataclass
from pyhcl.passes._pass import Pass
@dataclass
class ExpandSequential(Pass):
def run(self, c: Circuit):
modules: List[Module] = []
blocks: List[Statement] = []
block_map: Dict[str, List[Statement]] = {}
clock_map: Dict[str, Expression] = {}
def get_ref_name(e: Expression) -> str:
if isinstance(e, SubAccess):
return get_ref_name(e.expr)
elif isinstance(e, SubField):
return get_ref_name(e.expr)
elif isinstance(e, SubIndex):
return get_ref_name(e.expr)
elif isinstance(e, Reference):
return e.name
def expand_sequential_s(s: Statement, stmts: List[Statement], reg_map: Dict[str, DefRegister]):
if isinstance(s, Conditionally):
conseq_seq_map, conseq_com = expand_sequential_s(s.conseq, stmts, reg_map)
alt_seq_map, alt_com = expand_sequential_s(s.alt, stmts, reg_map)
for k in conseq_seq_map:
if k not in alt_seq_map:
alt_seq_map[k] = EmptyStmt()
com = Conditionally(s.pred, conseq_com, alt_com, s.info)
if isinstance(conseq_com, EmptyStmt):
if isinstance(alt_com, EmptyStmt):
com = EmptyStmt()
else:
com = Conditionally(DoPrim(Not(), [s.pred], [], s.pred.typ))
return {k: Conditionally(s.pred, v, alt_seq_map[k], s.info) for k, v in conseq_seq_map.items()}, com
elif isinstance(s, Block):
com_stmts: List[Statement] = []
seq_stmts_map: Dict[str, List[Statement]] = {}
for sx in s.stmts:
if isinstance(sx, Connect) and get_ref_name(sx.loc) in reg_map:
reg = reg_map[get_ref_name(sx.loc)]
if reg.clock.verilog_serialize() not in clock_map:
clock_map[reg.clock.verilog_serialize()] = reg.clock
if reg.clock.verilog_serialize() not in seq_stmts_map:
seq_stmts_map[reg.clock.verilog_serialize()] = []
seq_stmts_map[reg.clock.verilog_serialize()].append(Connect(sx.loc, sx.expr, sx.info, False, sx.bidirection, sx.mem))
elif isinstance(sx, Connect) and get_ref_name(sx.loc) not in reg_map:
com_stmts.append(sx)
elif isinstance(sx, Conditionally):
seq_when_map, com_when = expand_sequential_s(sx, stmts, reg_map)
for k in seq_when_map:
if k not in seq_stmts_map:
seq_stmts_map[k] = []
seq_stmts_map[k].append(seq_when_map[k])
com_stmts.append(com_when)
else:
stmts.append(sx)
return {k: Block(v) if len(v) > 0 else EmptyStmt() for k, v in seq_stmts_map.items()}, \
Block(com_stmts) if len(com_stmts) > 0 else EmptyStmt()
else:
return {}, s
def expand_sequential(stmts: List[Statement]) -> List[Statement]:
reg_map: Dict[str, DefRegister] = {sx.name: sx for sx in stmts if isinstance(sx, DefRegister)}
mem_map: Dict[str, DefMemPort] = {sx.name: sx for sx in stmts if isinstance(sx, DefMemPort)}
new_stmts: List[Statement] = []
for stmt in stmts:
if isinstance(stmt, Conditionally):
seq_map, com = expand_sequential_s(stmt, new_stmts, reg_map)
if not isinstance(com, EmptyStmt):
new_stmts.append(com)
for k in seq_map:
if k not in block_map:
block_map[k] = []
if not isinstance(seq_map[k], EmptyStmt):
block_map[k].append(seq_map[k])
else:
new_stmts.append(stmt)
reset_map: Dict[str, List[Statement]] = {}
reset_sign_map: Dict[str, List[Expression]] = {}
for reg_name in reg_map:
reg: DefRegister = reg_map[reg_name]
if reg.init is not None:
if reg.clock.verilog_serialize() not in reset_map:
reset_map[reg.clock.verilog_serialize()] = []
reset_map[reg.clock.verilog_serialize()].append(Connect(Reference(reg.name, reg.typ), reg.init, reg.info, False))
if reg.clock.verilog_serialize not in reset_sign_map:
reset_sign_map[reg.clock.verilog_serialize()] = []
reset_sign_map[reg.clock.verilog_serialize()].append(reg.reset)
for k in reset_map:
if len(reset_map[k]) > 0:
for rs in reset_sign_map[k]:
block_map[k].append(Conditionally(rs, Block(reset_map[k]), EmptyStmt()))
for k in mem_map:
mem: Reference = mem_map[k].mem
sig = DoPrim(And(), [Reference(f"{mem.name}_{k}_en", UIntType(IntWidth(1))),
Reference(f"{mem.name}_{k}_mask", UIntType(IntWidth(1)))], [], UIntType(IntWidth(1)))
con = Connect(SubAccess(mem, Reference(f"{mem.name}_{k}_addr", mem_map[k].index.typ), mem.typ),
Reference(f"{mem.name}_{k}_data", mem.typ),mem_map[k].info, False)
if mem_map[k].rw is False:
if mem_map[k].clk.verilog_serialize() not in block_map:
block_map[mem_map[k].clk.verilog_serialize()] = []
if mem_map[k].clk.verilog_serialize() not in clock_map:
clock_map[mem_map[k].clk.verilog_serialize()] = mem_map[k].clk
block_map[mem_map[k].clk.verilog_serialize()].append(Conditionally(sig, Block([con]), EmptyStmt()))
for k in block_map:
new_stmts.append(AlwaysBlock(block_map[k], clock_map[k]))
return new_stmts
def expand_sequential_m(m: DefModule) -> DefModule:
if isinstance(m, Module):
return Module(m.name, m.ports, Block(expand_sequential(m.body.stmts)), m.typ)
else:
return m
for m in c.modules:
modules.append(expand_sequential_m(m))
return Circuit(modules, c.main, c.info)

View File

@ -0,0 +1,70 @@
from typing import List
from numpy import isin
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
from pyhcl.passes.utils import AutoName
@dataclass
class ExpandWhens(Pass):
def run(self, c: Circuit) -> Circuit:
modules: List[DefModule] = []
def auto_gen_name():
return AutoName.auto_gen_name()
def last_name():
return AutoName.last_name()
def expand_whens(s: Statement, stmts: List[Statement], refs: Dict[str, List[Statement]], pred: Expression = None):
if isinstance(s, Conditionally):
expand_whens(s.conseq, stmts, refs, s.pred)
expand_whens(s.alt, stmts, refs, DoPrim(Not(), [s.pred], [], s.pred.typ))
elif isinstance(s, Block):
for sx in s.stmts:
expand_whens(sx, stmts, refs, pred)
elif isinstance(s, EmptyStmt):
...
elif isinstance(s, Connect):
if s.loc.serialize() not in refs:
refs[s.loc.serialize()] = []
refs[s.loc.serialize()].append(Conditionally(pred, Block([s]), EmptyStmt()))
else:
stmts.append(s)
def expand_whens_s(ss: List[Statement]) -> List[Statement]:
stmts: List[Statement] = []
refs: Dict[str, List[Statement]] = {}
for sx in ss:
if isinstance(sx, Conditionally):
expand_whens(sx, stmts, refs)
else:
stmts.append(sx)
for sx in refs.values():
if len(sx) <= 1:
sxx = sx.pop()
con = sxx.conseq.stmts.pop()
stmts.append(Connect(con.loc, ValidIf(sxx.pred, con.expr, con.expr.typ)))
else:
sxx = sx.pop()
con = sxx.conseq.stmts.pop()
stmts.append(DefNode(auto_gen_name(), ValidIf(sxx.pred, con.expr, con.expr.typ)))
while len(sx) > 1:
sxx = sx.pop()
con = sxx.conseq.stmts.pop()
stmts.append(DefNode(auto_gen_name(), Mux(sxx.pred, con.expr, Reference(AutoName.last_name(), con.expr.typ), con.expr.typ)))
sxx = sx.pop()
con = sxx.conseq.stmts.pop()
stmts.append(Connect(con.loc, Mux(sxx.pred, con.expr, Reference(last_name(), con.expr.typ), con.expr.typ)))
return stmts
def expand_whens_m(m: DefModule) -> DefModule:
if isinstance(m, Module):
return Module(m.name, m.ports, Block(expand_whens_s(m.body.stmts)), m.typ, m.info)
else:
return m
for m in c.modules:
modules.append(expand_whens_m(m))
return Circuit(modules, c.main, c.info)

View File

@ -0,0 +1,33 @@
from dataclasses import dataclass
from typing import List, Dict
from pyhcl.ir.low_ir import *
from pyhcl.passes._pass import Pass
@dataclass
class HandleInstance(Pass):
def run(self, c: Circuit) -> Circuit:
modules: List[DefModule] = []
refs: Dict[str, List[Port]] = {m.name: m.ports for m in c.modules}
def handle_instance_s(s: Statement):
if isinstance(s, Conditionally):
return Conditionally(s.pred, handle_instance_s(s.conseq), handle_instance_s(s.alt), s.info)
elif isinstance(s, Block):
return Block([handle_instance_s(sx) for sx in s.stmts])
elif isinstance(s, DefInstance):
if s.module in refs:
return DefInstance(s.name, s.module, refs[s.module], s.info)
else:
return s
else:
return s
def handle_instance_m(m: DefModule):
if isinstance(m, Module):
return Module(m.name, m.ports, handle_instance_s(m.body), m.typ, m.info)
else:
return m
for m in c.modules:
modules.append(handle_instance_m(m))
return Circuit(modules, c.main, c.info)

57
pyhcl/passes/optimize.py Normal file
View File

@ -0,0 +1,57 @@
from typing import List, Dict
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
@dataclass
class Optimize(Pass):
def run(self, c: Circuit):
def get_name(e: Expression) -> str:
if isinstance(e, (SubField, SubIndex, SubAccess)):
return get_name(e.expr)
else:
return e.name
def optimize_s(stmts: List[Statement]) -> List[Statement]:
defwires: Dict[str, Statement] = {}
connects: Dict[str, Statement] = {}
nodes: Dict[str, Statement] = {}
new_stmts: List[Statement] = []
for stmt in stmts:
if isinstance(stmt, DefWire):
defwires[stmt.name] = stmt
if isinstance(stmt, Connect):
connects[get_name(stmt.loc)] = stmt
for defwire in defwires.keys():
if defwire in connects:
nodes[defwire] = DefNode(defwire, connects[defwire].expr)
for stmt in stmts:
if isinstance(stmt, DefWire) and stmt.name in nodes:
new_stmts.append(nodes[stmt.name])
elif isinstance(stmt, Connect) and get_name(stmt.loc) in nodes:
...
else:
new_stmts.append(stmt)
return new_stmts
def optimize_m(m: DefModule) -> DefModule:
return Module(
m.name,
m.ports,
Block(optimize_s(m.body.stmts)),
m.typ,
m.info
)
new_modules = []
for m in c.modules:
if isinstance(m, Module):
new_modules.append(optimize_m(m))
else:
new_modules.append(m)
return Circuit(new_modules, c.main, c.info)

View File

@ -0,0 +1,66 @@
from pyhcl.ir.low_ir import *
from typing import List
from dataclasses import dataclass
from pyhcl.passes._pass import Pass
@dataclass
class RemoveAccess(Pass):
def run(self, c: Circuit) -> Circuit:
modules: List[Module] = []
def remove_access(e: Expression, type: Type = None, name: str = None) -> Expression:
if isinstance(e, Reference):
return Reference(f"{e.name}_{name}", type)
elif isinstance(e, SubIndex):
if name is None:
return remove_access(e.expr, type, f"{e.value}")
else:
return remove_access(e.expr, type, f"{e.value}_{name}")
elif isinstance(e, SubField):
if name is None:
return remove_access(e.expr, type, f"{e.name}")
else:
return remove_access(e.expr, type, f"{e.name}_{name}")
else:
return e
def remove_access_e(e: Expression) -> Expression:
if isinstance(e, (SubIndex, SubField)):
return remove_access(e, e.typ)
elif isinstance(e, ValidIf):
return ValidIf(remove_access_e(e.cond), remove_access_e(e.value), e.typ)
elif isinstance(e, Mux):
return Mux(remove_access_e(e.cond), remove_access_e(e.tval), remove_access_e(e.fval), e.typ)
elif isinstance(e, DoPrim):
return DoPrim(e.op, [remove_access_e(arg) for arg in e.args], e.consts, e.typ)
else:
return e
def remove_access_s(s: Statement) -> Statement:
if isinstance(s, Block):
stmts: List[Statement] = []
for sx in s.stmts:
stmts.append(remove_access_s(sx))
return Block(stmts)
elif isinstance(s, Conditionally):
return Conditionally(remove_access_e(s.pred), remove_access_s(s.conseq), remove_access_s(s.alt), s.info)
elif isinstance(s, DefRegister):
return DefRegister(s.name, s.typ, remove_access_e(s.clock), remove_access_e(s.reset), remove_access_e(s.init), s.info)
elif isinstance(s, DefNode):
return DefNode(s.name, remove_access_e(s.value), s.info)
elif isinstance(s, DefMemPort):
return DefMemPort(s.name, s.mem, remove_access_e(s.index), remove_access_e(s.clk), s.rw, s.info)
elif isinstance(s, Connect):
return Connect(remove_access_e(s.loc), remove_access_e(s.expr), s.info)
else:
return s
def remove_access_m(m: DefModule) -> DefModule:
if isinstance(m, Module):
return Module(m.name, m.ports, remove_access_s(m.body), m.typ, m.info)
else:
return m
for m in c.modules:
modules.append(remove_access_m(m))
return Circuit(modules, c.main, c.info)

View File

@ -0,0 +1,140 @@
from sqlite3 import connect
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.passes._pass import Pass
from pyhcl.passes.utils import get_binary_width, AutoName
@dataclass
class ReplaceSubaccess(Pass):
def run(self, c: Circuit) -> Circuit:
modules: List[DefModule] = []
def has_access(e: Expression) -> bool:
if isinstance(e, SubAccess):
return True
elif isinstance(e, (SubField, SubIndex)):
return has_access(e.expr)
else:
return False
def get_ref_name(e: Expression) -> str:
if isinstance(e, SubAccess):
return get_ref_name(e.expr)
elif isinstance(e, SubField):
return get_ref_name(e.expr)
elif isinstance(e, SubIndex):
return get_ref_name(e.expr)
elif isinstance(e, Reference):
return e.name
def auto_gen_name():
return AutoName.auto_gen_name()
def last_name():
return AutoName.last_name()
def replace_subaccess(e: Expression):
cons: List[Expression] = []
exps: List[Expression] = []
if isinstance(e, SubAccess):
xcons, xexps = replace_subaccess(e.expr)
if len(cons) == 0 and len(xexps) == 0:
if isinstance(e.expr.typ, VectorType):
for i in range(e.expr.typ.size):
cons.append(DoPrim(Eq(), [e.index, UIntLiteral(i, IntWidth(get_binary_width(e.expr.typ.size)))], [], UIntType(IntWidth(1))))
exps.append(SubIndex("", e.expr, i, e.typ))
else:
exps.append(e)
else:
if isinstance(e.expr.typ, VectorType):
for i in range(e.expr.typ.size):
for xcon in xcons:
cons.append(DoPrim(And(), [xcon, DoPrim(Eq(), [e.index, UIntLiteral(i, IntWidth(get_binary_width(e.expr.typ.size)))],
[], UIntType(IntWidth(1)))], [], UIntType(IntWidth(1))))
for xexp in xexps:
exps.append(SubIndex("", xexp, i, e.typ))
else:
cons, exps = xcons, xexps
elif isinstance(e, SubField):
xcons, xexps = replace_subaccess(e.expr)
cons = xcons
for xexp in xexps:
exps.append(SubField(xexp, e.name, e.typ))
elif isinstance(e, SubIndex):
xcons, xexps = replace_subaccess(e.expr)
cons = xcons
for xexp in xexps:
exps.append(SubIndex("", xexp, e.value, e.typ))
return cons, exps
def replace_subaccess_e(e: Expression, stmts: List[Statement], is_sink: bool = False, source: Expression = None) -> Expression:
if isinstance(e, ValidIf):
return ValidIf(replace_subaccess_e(e.cond, stmts), replace_subaccess_e(e.value, stmts), e.typ)
elif isinstance(e, Mux):
return Mux(replace_subaccess_e(e.cond, stmts), replace_subaccess_e(e.tval, stmts), replace_subaccess_e(e.fval, stmts), e.typ)
elif isinstance(e, DoPrim):
return DoPrim(e.op, [replace_subaccess_e(arg, stmts) for arg in e.args], e.consts, e.typ)
elif isinstance(e, (SubAccess, SubField, SubIndex)) and has_access(e):
if is_sink:
cons, exps = replace_subaccess(e)
gen_nodes: Dict[str, DefNode] = {}
for i in range(len(cons)):
stmts.append(Connect(exps[i], Mux(cons[i], source, exps[i], e.typ)))
return
else:
cons, exps = replace_subaccess(e)
gen_nodes: Dict[str, DefNode] = {}
for i in range(len(cons)):
if i == 0:
name = auto_gen_name()
gen_node = DefNode(name, ValidIf(cons[i], exps[i], e.typ))
gen_nodes[name] = gen_node
else:
last_node = gen_nodes[last_name()]
name = auto_gen_name()
gen_node = DefNode(name, Mux(cons[i], exps[i], last_node.value.value if i == 1 else Reference(last_node.name, last_node.value.typ), e.typ))
stmts.append(gen_node)
gen_nodes[name] = gen_node
return Reference(gen_nodes[last_name()].name, e.typ)
else:
return e
def replace_subaccess_s(s: Statement) -> Statement:
if isinstance(s, Block):
stmts: List[Statement] = []
for stmt in s.stmts:
if isinstance(stmt, Connect):
expr = replace_subaccess_e(stmt.expr, stmts)
loc = replace_subaccess_e(stmt.loc, stmts, True, expr)
if loc is not None:
stmts.append(Connect(loc, expr, stmt.info, stmt.blocking, stmt.bidirection, stmt.mem))
else:
...
# stmts.append(stmt)
elif isinstance(stmt, DefNode):
stmts.append(DefNode(stmt.name, replace_subaccess_e(stmt.value, stmts), stmt.info))
elif isinstance(stmt, DefRegister):
stmts.append(DefRegister(stmt.name, stmt.typ, stmt.clock, stmt.reset, replace_subaccess_e(stmt.init, stmts), stmt.info))
elif isinstance(stmt, Conditionally):
stmts.append(replace_subaccess_s(stmt))
else:
stmts.append(stmt)
return Block(stmts)
elif isinstance(s, EmptyStmt):
return EmptyStmt()
elif isinstance(s, Conditionally):
return Conditionally(s.pred, replace_subaccess_s(s.conseq), replace_subaccess_s(s.alt), s.info)
else:
return s
def replace_subaccess_m(m: DefModule) -> DefModule:
if isinstance(m, Module):
return Module(m.name, m.ports, replace_subaccess_s(m.body), m.typ, m.info)
else:
return m
for m in c.modules:
modules.append(replace_subaccess_m(m))
return Circuit(modules, c.main, c.info)

252
pyhcl/passes/utils.py Normal file
View File

@ -0,0 +1,252 @@
import math
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.passes.wir import DuplexFlow, Flow, SinkFlow, SourceFlow, UnknownFlow
from pyhcl.passes._pass import PassException
# CheckForm utils
class ModuleGraph:
nodes: Dict[str, set] = {}
def add(self, parent: str, child: str) -> List[str]:
if parent not in self.nodes.keys():
self.nodes[parent] = set()
self.nodes[parent].add(child)
return self.path_exists(child, parent, [child, parent])
def path_exists(self, child: str, parent: str, path: List[str] = None) -> List[str]:
cx = self.nodes[child] if child in self.nodes.keys() else None
if cx is not None:
if parent in cx:
return [parent] + path
else:
for cxx in cx:
new_path = self.path_exists(cxx, parent, [cxx] + path)
if len(new_path) > 0:
return new_path
return None
def is_max(w1: Width, w2: Width):
return IntWidth(max(w1.width, w2.width))
def to_flow(d: Direction) -> Flow:
if isinstance(d, Input):
return SourceFlow()
if isinstance(d, Output):
return SinkFlow()
def flow(e: Expression) -> Flow:
if isinstance(e, DoPrim):
return SourceFlow()
elif isinstance(e, UIntLiteral):
return SourceFlow()
elif isinstance(e, SIntLiteral):
return SourceFlow()
elif isinstance(e, Mux):
return SourceFlow()
elif isinstance(e, ValidIf):
return SourceFlow()
else:
raise PassException(f'flow: shouldn\'t be here - {e}')
def mux_type_and_widths(e1, e2) -> Type:
return mux_type_and_width(e1.typ, e2.typ)
def mux_type_and_width(t1: Type, t2: Type) -> Type:
if isinstance(t1, ClockType) and isinstance(t2, ClockType):
return ClockType()
elif isinstance(t1, AsyncResetType) and isinstance(t2, AsyncResetType):
return AsyncResetType()
elif isinstance(t1, UIntType) and isinstance(t2, UIntType):
return UIntType(is_max(t1.width, t2.width))
elif isinstance(t1, VectorType) and isinstance(t2, VectorType):
return VectorType(mux_type_and_width(t1.typ, t2.typ), t1.size)
elif isinstance(t1, BundleType) and isinstance(t2, BundleType):
return BundleType(map(lambda f1, f2: Field(f1.name, f1.flip, mux_type_and_width(f1.typ, f2.typ)), list(zip(t1.fields, t2.fields))))
else:
return UnknownType
def create_exps(e: Expression) -> List[Expression]:
if isinstance(e, Mux):
e1s = create_exps(e.tval)
e2s = create_exps(e.fval)
return list(map(lambda e1, e2: Mux(e.cond, e1, e2, mux_type_and_widths(e1, e2)), list(zip(e1s, e2s))))
elif isinstance(e, ValidIf):
return list(map(lambda e1: ValidIf(e.cond, e1, e1.typ), create_exps(e.value)))
else:
if isinstance(e.typ, GroundType):
return [e]
elif isinstance(e.typ, BundleType):
exprs = []
for f in e.typ.fields:
exprs = exprs + create_exps(Reference(f'{e.name}_{f.name}', f.typ))
return exprs
elif isinstance(e.typ, VectorType):
exprs = []
for i in range(e.typ.size):
exprs = exprs + create_exps(Reference(f'{e.name}_{i}', f.typ))
return exprs
def get_info(s: Statement) -> Info:
if hasattr(s, 'info'):
return s.info
else:
return NoInfo
def has_flip(typ: Type) -> bool:
if isinstance(typ, BundleType):
for f in typ.fields:
if isinstance(f, Flip):
return True
else:
return has_flip(f.typ)
elif isinstance(typ, VectorType):
return has_flip(typ.typ)
else:
return False
# InterTypes utils
def to_flip(d: Direction):
if isinstance(d, Output):
return Default()
if isinstance(d, Input):
return Flip()
def module_type(m: DefModule) -> BundleType:
fields = [Field(p.name, to_flip(p.direction), p.typ) for p in m.ports]
return BundleType(fields)
def field_type(v: Type, s: str) -> Type:
if isinstance(v, BundleType):
def match_type(f: Field) -> Type:
if f is None:
return UnknownType
return f.typ
for f in v.fields:
if f.name == s:
return match_type(f)
return UnknownType()
def sub_type(v: Type) -> Type:
if isinstance(v, VectorType):
return v.typ
return UnknownType()
def mux_type(e1: Expression, e2: Expression) -> Type:
return mux_types(e1.typ, e2.typ)
def mux_types(t1: Type, t2: Type) -> Type:
if isinstance(t1, ClockType) and isinstance(t2, ClockType):
return ClockType()
elif isinstance(t1, AsyncResetType) and isinstance(t2, AsyncResetType):
return AsyncResetType()
elif isinstance(t1, UIntType) and isinstance(t2, UIntType):
return t1
elif isinstance(t1, SIntType) and isinstance(t2, SIntType):
return t2
elif isinstance(t1, VectorType) and isinstance(t2, VectorType):
return VectorType(mux_types(t1.typ, t2.typ), t1.size)
elif isinstance(t1, BundleType) and isinstance(t2, BundleType):
return BundleType(list(map(lambda f1, f2: Field(f1.name, f1.flip, mux_types(f1.typ, f2.typ)), list(zip(t1.fields, t2.fields)))))
else:
return UnknownType()
def get_or_else(cond, a, b):
return a if cond else b
# CheckTypes utils
def swp_flow(f: Flow) -> Flow:
if isinstance(f, UnknownFlow):
return UnknownFlow()
elif isinstance(f, SourceFlow):
return SinkFlow()
elif isinstance(f, SinkFlow):
return SourceFlow()
elif isinstance(f, DuplexFlow):
return DuplexFlow()
else:
return Flow()
def swp_direction(d: Direction) -> Direction:
if isinstance(d, Input):
return Output()
elif isinstance(d, Output):
return Input()
else:
return Direction()
def swp_orientation(o: Orientation) -> Orientation:
if isinstance(o, Default):
return Default()
elif isinstance(o, Flip):
return Flip()
else:
return Orientation()
def times_d_flip(d: Direction, flip: Orientation) -> Direction:
if isinstance(flip, Default):
return d
elif isinstance(flip, Flip):
return swp_direction(d)
def times_g_d(g: Flow, d: Direction) -> Direction:
return times_d_g(d, g)
def times_d_g(d: Direction, g: Flow) -> Direction:
if isinstance(g, SinkFlow):
return d
elif isinstance(g, SourceFlow):
return swp_flow(d)
def times_g_flip(g: Flow, flip: Orientation) -> Flow:
return times_flip_g(flip, g)
def times_flip_g(flip: Orientation, g: Flow) -> Flow:
if isinstance(flip, Default):
return g
elif isinstance(flip, Flip):
return swp_flow(g)
def times_f_f(f1: Orientation, f2: Orientation) -> Orientation:
if isinstance(f2, Default):
return f1
elif isinstance(f2, Flip):
return swp_orientation(f1)
# CheckWidth Utils
def get_binary_width(target):
width = 1
while target / 2 >= 1:
width += 1
target = math.floor(target / 2)
return width
def get_width(w: Width) -> int:
if isinstance(w, UnknownWidth):
return 0
return w.width
def has_width(t: Type) -> bool:
if hasattr(t, 'width'):
return True
else:
return False
class AutoName:
endwith: int = -1
names: List[str] = []
@staticmethod
def auto_gen_name():
AutoName.endwith += 1
gen_name = f"GEN_{AutoName.endwith}"
AutoName.names.append(gen_name)
return gen_name
@staticmethod
def last_name():
return AutoName.names[-1]

View File

@ -0,0 +1,86 @@
from dataclasses import dataclass
from pyhcl.ir.low_ir import *
from pyhcl.passes._pass import Pass
from typing import List, Dict
from pyhcl.passes.utils import AutoName
@dataclass
class VerilogOptimize(Pass):
def run(self, c: Circuit):
modules: List[DefModule] = []
def auto_gen_node(s):
return isinstance(s, DefNode) and s.name.startswith("_T")
def get_name(e: Expression) -> str:
if isinstance(e, (SubAccess, SubField, SubIndex)):
return get_name(e.expr)
elif isinstance(e, Reference):
return e.verilog_serialize()
def verilog_optimize_e(expr: Expression, node_map: Dict[str, Statement], filter_nodes: set, stmts: List[Statement]) -> Expression:
if isinstance(expr, (UIntLiteral, SIntLiteral)):
return expr
elif isinstance(expr, Reference):
en = get_name(expr)
if en in node_map:
filter_nodes.add(en)
return verilog_optimize_e(node_map[en].value, node_map, filter_nodes, stmts)
else:
return expr
elif isinstance(expr, (SubField, SubIndex, SubAccess)):
return expr
elif isinstance(expr, Mux):
return Mux(
verilog_optimize_e(expr.cond, node_map, filter_nodes, stmts),
verilog_optimize_e(expr.tval, node_map, filter_nodes, stmts),
verilog_optimize_e(expr.fval, node_map, filter_nodes, stmts), expr.typ)
elif isinstance(expr, ValidIf):
return ValidIf(
verilog_optimize_e(expr.cond, node_map, filter_nodes, stmts),
verilog_optimize_e(expr.value, node_map, filter_nodes, stmts), expr.typ)
elif isinstance(expr, DoPrim):
args = list(map(lambda arg: verilog_optimize_e(arg, node_map, filter_nodes, stmts), expr.args))
if isinstance(expr.op, Bits) and isinstance(args[0], DoPrim):
name = AutoName.auto_gen_name()
stmts.append(DefNode(name, args[0]))
args = [Reference(name, args[0].typ)]
return DoPrim(expr.op, args, expr.consts, expr.typ)
else:
return expr
def verilog_optimize_s(stmt: Statement, node_map: Dict[str, Statement], filter_nodes: set, stmts: List[Statement] = None) -> Statement:
if isinstance(stmt, Block):
node_map = {**node_map ,**{sx.name: sx for sx in stmt.stmts if auto_gen_node(sx)}}
cat_stmts = []
for sx in stmt.stmts:
if isinstance(sx, Connect):
cat_stmts.append(Connect(verilog_optimize_e(sx.loc, node_map, filter_nodes, cat_stmts),
verilog_optimize_e(sx.expr, node_map, filter_nodes, cat_stmts), sx.info, sx.blocking, sx.bidirection, sx.mem))
elif isinstance(sx, DefNode):
cat_stmts.append(DefNode(sx.name, verilog_optimize_e(sx.value, node_map, filter_nodes, cat_stmts), sx.info))
elif isinstance(sx, Conditionally):
cat_stmts.append(Conditionally(verilog_optimize_e(sx.pred, node_map, filter_nodes, cat_stmts), verilog_optimize_s(sx.conseq, node_map, filter_nodes, cat_stmts),
verilog_optimize_s(sx.alt, node_map, filter_nodes, cat_stmts), sx.info))
else:
cat_stmts.append(sx)
cat_stmts = [sx for sx in cat_stmts if not (isinstance(sx, DefNode) and sx.name in filter_nodes)]
return Block(cat_stmts)
elif isinstance(stmt, Conditionally):
return Conditionally(verilog_optimize_e(stmt.pred, node_map, filter_nodes, stmts), verilog_optimize_s(stmt.conseq, node_map, filter_nodes, stmts),
verilog_optimize_s(stmt.alt, node_map, filter_nodes, stmts), stmt.info)
else:
return stmt
def verilog_optimize_m(m: DefModule) -> DefModule:
node_map: Dict[str, DefNode] = {}
filter_nodes: set = set()
if isinstance(m, Module):
return Module(m.name, m.ports, verilog_optimize_s(m.body, node_map, filter_nodes), m.typ, m.info)
else:
return m
for m in c.modules:
modules.append(verilog_optimize_m(m))
return Circuit(modules, c.main, c.info)

81
pyhcl/passes/wir.py Normal file
View File

@ -0,0 +1,81 @@
from abc import ABC
from pyhcl.ir.low_ir import *
class Flow(ABC):
...
class SourceFlow(Flow):
...
class SinkFlow(Flow):
...
class DuplexFlow(Flow):
...
class UnknownFlow(Flow):
...
class Kind(ABC):
...
class PortKind(Kind):
...
class VarWidth(Width):
name: str
def serialize(self):
return self.name
class WrappedType(Type):
def __init__(self, t: Type):
self.t = t
def __eq__(self, o):
if isinstance(o, WrappedType):
return WrappedType.compare(self.t, o.t)
else:
return False
@staticmethod
def compare(sink: Type, source: Type):
def legal_reset_type(self, typ: Type) -> bool:
if isinstance(typ, UIntType) and isinstance(typ.width, IntWidth):
return typ.width.width == 1
elif isinstance(typ, AsyncResetType):
return True
elif isinstance(typ, ResetType):
return True
else:
return False
if isinstance(sink, UIntType) and isinstance(source, UIntType):
return True
elif isinstance(sink, SIntType) and isinstance(source, SIntType):
return True
elif isinstance(sink, ClockType) and isinstance(source, ClockType):
return True
elif isinstance(sink, AsyncResetType) and isinstance(source, AsyncResetType):
return True
elif isinstance(sink, ResetType):
return legal_reset_type(source)
elif isinstance(source, ResetType):
return legal_reset_type(sink)
elif isinstance(sink, VectorType) and isinstance(source, VectorType):
return sink.size == source.size and WrappedType.compare(sink.typ, source.typ)
elif isinstance(sink, BundleType) and isinstance(source, BundleType):
final = True
for f1, f2 in list(zip(sink.fields, source.fields)):
f1_final = WrappedType.compare(f2.typ, f1.typ) if isinstance(f1.flip, Flip) else WrappedType.compare(f1.typ, f2.typ)
final = f1.flip == f2.flip and f1.name == f2.name and f1_final and final
return len(sink.fields) == len(source.fields) and final
else:
return False
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...

View File

@ -120,15 +120,16 @@ class Simlite(object):
with open(f"./simulation/{self.dut_name}-harness.cpp", "w+") as f:
f.write(harness_code)
# 在simulation文件夹下创建dut_name-harness.fir写入firrtl代码
with open(f"./simulation/{self.dut_name}.fir", "w+") as f:
f.write(self.low_module.serialize())
# 调用FIRRTL工具链
# with open(f"./simulation/{self.dut_name}.fir", "w+") as f:
# f.write(self.low_module.serialize())
# 调用firrtl命令传入firrtl代码得到verilog代码
# firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog
# print(f"firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog")
os.system(
f"firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog")
# os.system(
# f"firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog")
# # 调用PyHCL编译链
with open(f"./simulation/{self.dut_name}.v", "w+") as f:
f.write(Verilog(self.low_module).emit())
vfn = "{}.v".format(self.dut_name) # {dut_name}.v
hfn = "{}-harness.cpp".format(self.dut_name) # {dut_name}-harness.cpp

0
pyhcl/tester/__init__.py Normal file
View File

View File

@ -0,0 +1,74 @@
from abc import ABC, abstractclassmethod
from typing import List
from pyhcl.ir.low_ir import *
from pyhcl.tester.symbol_table import SymbolTable
class ClockStepper(ABC):
@abstractclassmethod
def bump_clock(self):
...
@abstractclassmethod
def run(self):
...
@abstractclassmethod
def get_cycle_count(self):
...
@abstractclassmethod
def combinational_bump(self):
...
class SingleClockStepper(ClockStepper):
def __init__(self, mname: str, symbol: str, executor, table: SymbolTable):
self.mname: str = mname
self.clock_symbol: str = symbol
self.executor = executor
self.table: SymbolTable = table
self.clock_cycles = 0
self.combinational_bumps = 0
def handle_name(self, name):
names = name.split(".")
names.reverse()
return names
def bump_clock(self, mname: str, clock_symbol: str, value: int):
self.table.set_symbol_value(mname, self.handle_name(clock_symbol), value)
self.clock_cycles += 1
def combinational_bump(self, value: int):
self.combinational_bumps += value
def get_cycle_count(self):
return self.clock_cycles
def run(self, steps: int):
def raise_clock():
self.table[self.mname][self.clock_symbol] ^= 1
self.executor.execute(self.mname)
self.combinational_bumps = 0
def lower_clock():
self.table[self.mname][self.clock_symbol] ^= 1
self.combinational_bumps = 0
for _ in range(steps):
if self.executor.get_inputchange():
self.executor.execute(self.mname)
self.clock_cycles += 1
raise_clock()
lower_clock()
class MultiClockStepper(ClockStepper):
# TODO: Add MultiCLockStepper
...

296
pyhcl/tester/compiler.py Normal file
View File

@ -0,0 +1,296 @@
from functools import reduce
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.tester.wir import *
from pyhcl.tester.symbol_table import SymbolTable
from pyhcl.tester.exception import TesterException
from pyhcl.tester.utils import DAG
@dataclass(frozen=True)
class TesterCompiler:
symbol_table: SymbolTable
dags = {}
modules = {}
def get_in_port_name(self, name: str, t: Type, d: Direction) -> List[str]:
if isinstance(d, Input) and isinstance(t, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
return [name]
elif isinstance(d, Input) and isinstance(t, (VectorType, MemoryType)):
names = []
pnames = self.get_in_port_name(name, t.typ, d)
for pn in pnames:
for i in range(t.size):
names.append(f"{pn}[{i}]")
return names
elif isinstance(t, BundleType):
names = []
for f in t.fields:
pnames = []
if isinstance(d, Input) and isinstance(f.flip, Default):
pnames += self.get_in_port_name(f.name, f.typ, d)
elif isinstance(d, Output) and isinstance(f.flip, Flip):
pnames += self.get_in_port_name(f.name, f.typ, Input())
for pn in pnames:
names.append(f"{name}.{pn}")
return names
else:
return []
def get_out_port_name(self, name: str, t: Type, d: Direction) -> List[str]:
if isinstance(d, Output) and isinstance(t, (UIntType, SIntType)):
return [name]
elif isinstance(d, Output) and isinstance(t, (VectorType, MemoryType)):
names = []
pnames = self.get_in_port_name(name, t.typ, d)
for pn in pnames:
for i in range(t.size):
names.append(f"{pn}[{i}]")
return names
elif isinstance(t, BundleType):
names = []
for f in t.fields:
pnames = []
if isinstance(d, Output) and isinstance(f.flip, Default):
pnames += self.get_out_port_name(f.name, f.typ, d)
elif isinstance(d, Input) and isinstance(f.flip, Flip):
pnames += self.get_out_port_name(f.name, f.typ, Output())
for pn in pnames:
names.append(f"{name}.{pn}")
return names
else:
return []
def gen_dag_nodes(self, name: str, typ: Type):
if isinstance(typ, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
return [name]
elif isinstance(typ, VectorType):
names = []
pre_names = self.gen_dag_nodes(name, typ.typ)
for n in pre_names:
for i in range(typ.size):
names.append(f"{n}[{i}]")
return names
elif isinstance(typ, MemoryType):
names = []
pre_names = self.gen_dag_nodes(name, typ.typ)
for n in pre_names:
for i in range(typ.size):
names.append(f"{n}[{i}]")
return names
elif isinstance(typ, BundleType):
names = []
for f in typ.fields:
pre_names = self.gen_dag_nodes(f.name, f.typ)
for n in pre_names:
names.append(f"{name}.{n}")
return names
def add_dag_node(self, mname: str, name: str, typ: Type):
vs = self.gen_dag_nodes(name, typ)
for v in vs:
self.dags[mname].add_node_if_not_exists(v)
def add_dag_edge(self, mname: str, ind: Expression, dep: Expression):
if isinstance(dep, (Reference, SubAccess, SubIndex, SubField)):
self.dags[mname].add_edge(dep.serialize(), ind.serialize())
elif isinstance(dep, Mux):
self.add_dag_edge(mname, ind, dep.tval)
self.add_dag_edge(mname, ind, dep.fval)
elif isinstance(dep, ValidIf):
self.add_dag_edge(mname, ind, dep.value)
# elif isinstance(dep, DoPrim):
# for arg in dep.args:
# self.add_dag_edge(mname, ind, arg)
def gen_working_ir(self, mname: str, names: list, expr: Expression) -> Expression:
if isinstance(expr, SubField):
names.append(expr.name)
e = self.gen_working_ir(mname, names, expr.expr)
get_func, set_func = e.get_func, e.set_func
return WSubField(expr,
lambda table=None: get_func(table),
lambda s, table=None: set_func(s, table))
elif isinstance(expr, SubAccess):
index = self.gen_working_ir(mname, [], expr.index)
index_get_func = index.get_func
names.append(index_get_func())
e = self.gen_working_ir(mname, names, expr.expr)
get_func, set_func = e.get_func, e.set_func
return WSubAccess(expr,
lambda table=None: get_func(table),
lambda s, table=None: set_func(s, table))
elif isinstance(expr, SubIndex):
# size = expr.expr.typ.size-1
names.append(expr.value)
e = self.gen_working_ir(mname, names, expr.expr)
get_func, set_func = e.get_func, e.set_func
return WSubField(expr,
lambda table=None: get_func(table),
lambda s, table=None: set_func(s, table))
else:
name = expr.serialize()
names.append(name)
if not self.symbol_table.has_symbol(mname, name):
raise TesterException(f"Module [{mname}] Reference {name} is not declared")
return WReference(expr,
lambda table=None: self.symbol_table.get_symbol_value(mname, names, table),
lambda s, table=None: self.symbol_table.set_symbol_value(mname, names, s, table))
def compile_op(self, op, args, consts):
if isinstance(op, Add):
return lambda table=None: reduce(lambda x, y: x.get_value(table) + y.get_value(table), args + consts)
elif isinstance(op, Sub):
return lambda table=None: reduce(lambda x, y: x.get_value(table) - y.get_value(table), args + consts)
elif isinstance(op, Mul):
return lambda table=None: reduce(lambda x, y: x.get_value(table) * y.get_value(table), args + consts)
elif isinstance(op, Div):
return lambda table=None: int(reduce(lambda x, y: x.get_value(table) / y.get_value(table), args + consts))
elif isinstance(op, Rem):
return lambda table=None: reduce(lambda x, y: x.get_value(table) % y.get_value(table), args + consts)
elif isinstance(op, Lt):
return lambda table=None: reduce(lambda x, y: x.get_value(table) < y.get_value(table), args + consts)
elif isinstance(op, Leq):
return lambda table=None: reduce(lambda x, y: x.get_value(table) <= y.get_value(table), args + consts)
elif isinstance(op, Gt):
return lambda table=None: reduce(lambda x, y: x.get_value(table) > y.get_value(table), args + consts)
elif isinstance(op, Geq):
return lambda table=None: reduce(lambda x, y: x.get_value(table) >= y.get_value(table), args + consts)
elif isinstance(op, Eq):
return lambda table=None: reduce(lambda x, y: x.get_value(table) == y.get_value(table), args + consts)
elif isinstance(op, Neq):
return lambda table=None: reduce(lambda x, y: x.get_value(table) != y.get_value(table), args + consts)
elif isinstance(op, Neg):
return lambda table=None: -(args + consts)[0].get_value(table)
elif isinstance(op, Not):
return lambda table=None: not (args + consts)[0].get_value(table)
elif isinstance(op, And):
return lambda table=None: reduce(lambda x, y: x.get_value(table) & y.get_value(table), args + consts)
elif isinstance(op, Or):
return lambda table=None: reduce(lambda x, y: x.get_value(table) | y.get_value(table), args + consts)
elif isinstance(op, Xor):
return lambda table=None: reduce(lambda x, y: x.get_value(table) ^ y.get_value(table), args + consts)
elif isinstance(op, Shl):
return lambda table=None: reduce(lambda x, y: x.get_value(table) << y.get_value(table), args + consts)
elif isinstance(op, Shr):
return lambda table=None: reduce(lambda x, y: x.get_value(table) >> y.get_value(table), args + consts)
elif isinstance(op, Bits):
def bits(args, consts, table=None):
value = '{:032b}'.format(args[0].get_value(table))
value = value[::-1]
value_width = len(value)
lsb = consts[0].get_value(table) if consts[0].get_value(table) < value_width else 0
msb = consts[1].get_value(table) if consts[1].get_value(table) < value_width else 0
value = value[msb: lsb] if lsb > msb else value[lsb]
return int(value, 2)
return lambda table=None: bits(args, consts, table)
elif isinstance(op, Cat):
def cat(args, table):
hi = args[0].get_value(table) if isinstance(args[0].get_value(table), str) else bin(args[0].get_value(table))[2:]
lo = args[1].get_value(table) if isinstance(args[1].get_value(table), str) else bin(args[1].get_value(table))[2:]
# max_len = len(hi) if len(hi) >= len(lo) else len(lo)
# hi = '{:032b}'.format(args[0].get_value(table))[-max_len:]
# lo = '{:032b}'.format(args[1].get_value(table))[-max_len:]
return hi+lo
return lambda table=None: cat(args, table)
elif isinstance(op, (AsUInt, AsSInt)):
return lambda table=None: int(args[0].get_value(table))
elif isinstance(op, AsClock):
return lambda table=None: int(args[0].get_value(table)) % 2
else:
return lambda table=None: None
def compile_e(self, mname: str, expr: Expression) -> Expression:
if isinstance(expr, (Reference, SubField, SubAccess, SubIndex)):
names = []
return self.gen_working_ir(mname, names, expr)
elif isinstance(expr, DoPrim):
args = list(map(lambda arg: self.compile_e(mname, arg), expr.args))
consts = [WInt(const) for const in expr.consts]
return WDoPrim(expr, self.compile_op(expr.op, args, consts))
elif isinstance(expr, Mux):
cond = self.compile_e(mname, expr.cond)
tval = self.compile_e(mname, expr.tval)
fval = self.compile_e(mname, expr.fval)
return WMux(expr, lambda table=None: tval.get_value(table) if cond.get_value(table) else fval.get_value(table))
elif isinstance(expr, ValidIf):
cond = self.compile_e(mname, expr.cond)
value = self.compile_e(mname, expr.value)
return WValidIf(expr, lambda table=None: value.get_value(table) if cond.get_value(table) else None)
elif isinstance(expr, UIntLiteral):
return WUIntLiteral(expr)
elif isinstance(expr, SIntLiteral):
return WSIntLiteral(expr)
else:
return expr
def compile_s(self, mname: str, s: Statement) -> Statement:
if isinstance(s, EmptyStmt):
return EmptyStmt()
elif isinstance(s, Conditionally):
return Conditionally(self.compile_e(mname, s.pred), self.compile_s(mname, s.conseq), self.compile_s(mname, s.alt), s.info)
elif isinstance(s, Block):
insts = [sx for sx in s.stmts if isinstance(sx, DefInstance)]
nodes = [sx for sx in s.stmts if isinstance(sx, DefNode)]
wires = [sx for sx in s.stmts if isinstance(sx, DefWire)]
regs = [sx for sx in s.stmts if isinstance(sx, DefRegister)]
mems = [sx for sx in s.stmts if isinstance(sx, WDefMemory)]
cons = [sx for sx in s.stmts if isinstance(sx, Connect)]
stmts = regs + mems + insts + wires + nodes + cons
return Block([self.compile_s(mname, sx) for sx in stmts])
elif isinstance(s, DefRegister):
self.add_dag_node(mname, s.name, s.typ)
self.symbol_table.set_symbol(mname, s)
return DefRegister(s.name,
s.typ,
self.compile_e(mname, s.clock),
self.compile_e(mname, s.reset),
self.compile_e(mname, s.init),
s.info)
elif isinstance(s, WDefMemory):
self.add_dag_node(mname, s.name, s.memType)
self.symbol_table.set_symbol(mname, s)
return s
elif isinstance(s, DefInstance):
for p in s.ports:
self.add_dag_node(mname, f"{s.name}_{p.name}", p.typ)
self.symbol_table.set_symbol(mname, s)
return s
elif isinstance(s, DefWire):
self.add_dag_node(mname, s.name, s.typ)
self.symbol_table.set_symbol(mname, s)
return s
elif isinstance(s, DefNode):
self.add_dag_node(mname, s.name, s.value.typ)
self.add_dag_edge(mname, Reference(s.name, s.value.typ), s.value)
self.symbol_table.set_symbol(mname, s)
return DefNode(s.name, self.compile_e(mname, s.value), s.info)
elif isinstance(s, Connect):
self.add_dag_edge(mname, s.loc, s.expr)
return Connect(self.compile_e(mname, s.loc), self.compile_e(mname, s.expr), s.info)
else:
return s
def compile_p(self, mname: str, p: Port):
self.add_dag_node(mname, p.name, p.typ)
self.symbol_table.set_symbol(mname, p)
return p
def compile_m(self, m: DefModule):
if isinstance(m, Module):
self.dags[m.name] = DAG()
self.symbol_table.set_module(m.name)
ports = list(map(lambda p: self.compile_p(m.name, p), m.ports))
body = self.compile_s(m.name, m.body)
return Module(m.name, ports, body, m.typ, m.info)
elif isinstance(m, ExtModule):
self.symbol_table.set_module(m.name)
ports = list(map(lambda p: self.compile_p(m.name, p), m.ports))
return ExtModule(m.name, ports, m.defname, m.typ, m.info)
def compile(self, c: Circuit):
for m in c.modules:
self.modules[m.name] = m
modules = list(map(lambda m: self.compile_m(m), c.modules))
return Circuit(modules, c.main, c.info), self.dags

View File

@ -0,0 +1,6 @@
class TesterException(Exception):
def __init__(self, message: str):
self.message = message
def __str__(self):
return self.message

244
pyhcl/tester/executer.py Normal file
View File

@ -0,0 +1,244 @@
from typing import Dict, List
from collections import OrderedDict
from copy import deepcopy
from pyhcl.ir.low_ir import *
from pyhcl.ir.low_prim import *
from pyhcl.tester.compiler import TesterCompiler
from pyhcl.tester.symbol_table import SymbolTable
from pyhcl.tester.clock_stepper import SingleClockStepper
from pyhcl.passes.check_form import CheckHighForm
from pyhcl.passes.check_types import CheckTypes
from pyhcl.passes.check_flows import CheckFlow
from pyhcl.passes.check_widths import CheckWidths
from pyhcl.passes.auto_inferring import AutoInferring
from pyhcl.passes.replace_subaccess import ReplaceSubaccess
from pyhcl.passes.expand_aggregate import ExpandAggregate
from pyhcl.passes.expand_whens import ExpandWhens
from pyhcl.passes.expand_memory import ExpandMemory
from pyhcl.passes.handle_instance import HandleInstance
from pyhcl.passes.optimize import Optimize
from pyhcl.passes.remove_access import RemoveAccess
from pyhcl.passes.utils import AutoName
class TesterExecuter:
def __init__(self, circuit: Circuit):
self.circuit = circuit
self.symbol_table = SymbolTable()
self.reg_table = {}
self.mem_table = {}
self.clock_table = {}
self.inputchange = False
def handle_name(self, name):
names = name.split(".")
names.reverse()
return names
def get_inputchange(self):
return self.inputchange
def get_ref_name(self, e: Expression):
if isinstance(e, SubField):
return self.get_ref_name(e.expr)
elif isinstance(e, SubIndex):
return self.get_ref_name(e.expr)
elif isinstance(e, SubAccess):
return self.get_ref_name(e.expr)
else:
return e.name
def execute_stmt(self, m: Module, stmt: Statement, table=None):
if isinstance(stmt, Connect):
if stmt.loc.expr.serialize() in self.reg_table:
if self.reg_table[stmt.loc.expr.serialize()].reset.get_value(table) == 0:
if stmt.expr.get_value(table) is not None:
stmt.loc.set_value(stmt.expr.get_value(table), table)
else:
self.symbol_table.set_symbol_value(m.name, self.handle_name(stmt.loc.expr.name),
self.reg_table[stmt.loc.expr.serialize()].init.get_value(table), table)
elif stmt.loc.expr.serialize() in self.mem_table:
mem_data = stmt.loc.expr.serialize()
mem = mem_data.split("_")[0]
mem_addr = mem_data.replace("data", "addr")
mem_en = mem_data.replace("data", "en")
mem_mask = mem_data.replace("data", "mask")
if self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_en), table) > 0 and \
self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_mask), table) > 0:
if self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table) is not None:
self.symbol_table[m.name][mem][self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)]\
= self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)
elif stmt.expr.expr.serialize() in self.mem_table:
mem_data = stmt.expr.expr.serialize()
mem = mem_data.split("_")[0]
mem_addr = mem_data.replace("data", "addr")
mem_en = mem_data.replace("data", "en")
if self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_en), table) > 0:
if self.symbol_table[m.name][mem][self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)] is not None:
self.symbol_table.set_symbol_value(m.name, self.handle_name(mem_data),
self.symbol_table[m.name][mem][self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)], table)
else:
if stmt.expr.get_value(table) is not None:
stmt.loc.set_value(stmt.expr.get_value(table), table)
elif isinstance(stmt, DefNode):
self.symbol_table.set_symbol_value(m.name, self.handle_name(stmt.name), stmt.value.get_value(table), table)
elif isinstance(stmt, WDefMemory):
for rw in stmt.writers:
mem_table.append(f"{stmt.name}_{rw}_data")
for re in stmt.readers:
mem_table.append(f"{stmt.name}_{re}_data")
elif isinstance(stmt, Block):
for s in stmt.stmts:
self.execute_stmt(s, table)
def execute_module(self, m: Module, ms: Dict[str, DefModule], table=None):
execute_stmts = OrderedDict()
instances = OrderedDict()
def get_in_port_name(name: str, t: Type, d: Direction) -> List[str]:
if isinstance(d, Input) and isinstance(t, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
return [name]
elif isinstance(d, Input) and isinstance(t, (VectorType, MemoryType)):
names = []
pnames = get_in_port_name(name, t.typ, d)
for pn in pnames:
for i in range(t.size):
names.append(f"{pn}[{i}]")
return names
elif isinstance(t, BundleType):
names = []
for f in t.fields:
pnames = []
if isinstance(d, Input) and isinstance(f.flip, Default):
pnames += get_in_port_name(f.name, f.typ, d)
elif isinstance(d, Output) and isinstance(f.flip, Flip):
pnames += get_in_port_name(f.name, f.typ, Input())
for pn in pnames:
names.append(f"{name}.{pn}")
return names
else:
return []
def get_out_port_name(name: str, t: Type, d: Direction) -> List[str]:
if isinstance(d, Output) and isinstance(t, (UIntType, SIntType)):
return [name]
elif isinstance(d, Output) and isinstance(t, (VectorType, MemoryType)):
names = []
pnames = get_in_port_name(name, t.typ, d)
for pn in pnames:
for i in range(t.size):
names.append(f"{pn}[{i}]")
return names
elif isinstance(t, BundleType):
names = []
for f in t.fields:
pnames = []
if isinstance(d, Output) and isinstance(f.flip, Default):
pnames += get_out_port_name(f.name, f.typ, d)
elif isinstance(d, Input) and isinstance(f.flip, Flip):
pnames += get_out_port_name(f.name, f.typ, Output())
for pn in pnames:
names.append(f"{name}.{pn}")
return names
else:
return []
def _deal_stmt(s: Statement):
if isinstance(s, Block):
for stmt in s.stmts:
_deal_stmt(stmt)
elif isinstance(s, Connect):
execute_stmts[s.loc.expr.serialize()] = s
elif isinstance(s, DefNode):
execute_stmts[s.name] = s
elif isinstance(s, DefRegister):
self.reg_table[s.name] = s
elif isinstance(s, WDefMemory):
self.mem_table[s.name] = s
elif isinstance(s, DefInstance):
instances[s.name] = s
_deal_stmt(m.body)
for sx in m.body.stmts:
self.execute_stmt(m, sx, table)
for ins in instances:
ref_module_name = instances[ins].module
ref_module = ms[ref_module_name]
ref_table = deepcopy(self.symbol_table.table[ref_module_name])
module_inputs = []
for p in ref_module.ports:
module_inputs += get_in_port_name(p.name, p.typ, p.direction)
ref_inputs = [f"{ins}_{mi}" for mi in module_inputs]
for i in range(len(module_inputs)):
self.symbol_table.set_symbol_value(ref_module_name,
self.handle_name(module_inputs[i]),
self.symbol_table.get_symbol_value(m.name, self.handle_name(ref_inputs[i])),
ref_table)
self.execute_module(ref_module, ms, ref_table)
module_outputs = []
for p in ref_module.ports:
module_outputs += get_out_port_name(p.name, p.typ, p.direction)
ref_outputs = [f"{ins}_{mi}" for mi in module_outputs]
for i in range(len(module_outputs)):
self.symbol_table.set_symbol_value(m.name,
self.handle_name(ref_outputs[i]),
self.symbol_table.get_symbol_value(ref_module_name, self.handle_name(module_outputs[i]), ref_table))
for v in self.dags[m.name].travel_graph(ref_outputs):
if v in execute_stmts:
self.execute_stmt(m, execute_stmts[v], table)
def init_clock(self, table = None):
if table is None:
table = self.symbol_table.clock_table
for mname in table:
if mname not in self.clock_table:
self.clock_table[mname] = {}
for symbol in table[mname]:
self.clock_table[mname][symbol] = SingleClockStepper(mname, symbol, self, table)
def init_executer(self):
AutoName()
self.circuit = CheckHighForm(self.circuit).run()
self.circuit = AutoInferring().run(self.circuit)
self.circuit = CheckTypes().run(self.circuit)
self.circuit = CheckFlow().run(self.circuit)
self.circuit = CheckWidths().run(self.circuit)
self.circuit = ExpandMemory().run(self.circuit)
self.circuit = ReplaceSubaccess().run(self.circuit)
self.circuit = ExpandAggregate().run(self.circuit)
self.circuit = RemoveAccess().run(self.circuit)
self.circuit = ExpandWhens().run(self.circuit)
self.circuit = HandleInstance().run(self.circuit)
self.circuit = Optimize().run(self.circuit)
self.compiler = TesterCompiler(self.symbol_table)
self.compiled_circuit, self.dags = self.compiler.compile(self.circuit)
self.init_clock()
def set_value(self, mname: str, name: str, singal: int):
self.inputchange = True
self.symbol_table.set_symbol_value(mname, self.handle_name(name), singal)
def get_value(self, mname: str, name: str):
if self.inputchange:
self.execute(mname)
self.inputchange = False
return self.symbol_table.get_symbol_value(mname, self.handle_name(name))
def step(self, n: int, mname: str):
if n > 0:
for name in self.clock_table[mname]:
self.clock_table[mname][name].run(n)
def execute(self, mname: str):
ms = {m.name: m for m in self.compiled_circuit.modules}
m = ms[mname]
self.execute_module(m, ms)

View File

@ -0,0 +1,72 @@
from pyhcl.ir.low_ir import *
@dataclass(frozen=True)
class SymbolTable:
table = {}
clock_table = {}
def gen_typ(self, typ: Type):
if isinstance(typ, (AsyncResetType, ResetType, ClockType, UIntType, SIntType)):
return 0
elif isinstance(typ, VectorType):
return [self.gen_typ(typ.typ) for _ in range(typ.size)]
elif isinstance(typ, BundleType):
return {f.name: self.gen_typ(f.typ) for f in typ.fields}
elif isinstance(typ, MemoryType):
return [self.gen_typ(typ.typ) for _ in range(typ.size)]
def has_symbol(self, mname: str, name: str) -> bool:
return name in self.table[mname]
def set_module(self, mname: str):
self.table[mname] = {}
self.clock_table[mname] = {}
def set_symbol(self, mname: str, symbol):
if isinstance(symbol, Port):
if isinstance(symbol.typ, ClockType):
self.clock_table[mname][symbol.name] = self.gen_typ(symbol.typ)
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
if isinstance(symbol, DefWire):
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
elif isinstance(symbol, DefRegister):
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
elif isinstance(symbol, WDefMemory):
self.table[mname][symbol.name] = self.gen_typ(symbol.memType)
for rw in symbol.writers:
self.table[mname][f"{symbol.name}_{rw}_data"] = self.gen_typ(symbol.dataType)
self.table[mname][f"{symbol.name}_{rw}_addr"] = self.gen_typ(UIntType(IntWidth(get_binary_width(symbol.depth))))
self.table[mname][f"{symbol.name}_{rw}_clk"] = self.gen_typ(ClockType())
self.table[mname][f"{symbol.name}_{rw}_en"] = self.gen_typ(UIntType(IntWidth(1)))
self.table[mname][f"{symbol.name}_{rw}_mask"] = self.gen_typ(UIntType(IntWidth(1)))
for rr in symbol.readers:
self.table[mname][f"{symbol.name}_{rr}_data"] = self.gen_typ(symbol.dataType)
self.table[mname][f"{symbol.name}_{rr}_addr"] = self.gen_typ(UIntType(IntWidth(get_binary_width(symbol.depth))))
self.table[mname][f"{symbol.name}_{rr}_clk"] = self.gen_typ(ClockType())
self.table[mname][f"{symbol.name}_{rr}_en"] = self.gen_typ(UIntType(IntWidth(1)))
elif isinstance(symbol, DefNode):
self.table[mname][symbol.name] = self.gen_typ(symbol.value.typ)
elif isinstance(symbol, DefInstance):
for p in symbol.ports:
name = f"{symbol.name}_{p.name}"
self.table[mname][name] = self.gen_typ(p.typ)
else:
...
def get_symbol_value(self, mname: str, names: list, table = None):
if table is None:
table = self.table[mname]
ns = names[:]
while len(ns) > 1:
table = table[ns.pop()]
return table[ns.pop()]
def set_symbol_value(self, mname: str, names: list, signal: int, table = None):
if table is None:
table = self.table[mname]
ns = names[:]
while len(ns) > 1:
table = table[ns.pop()]
table[ns.pop()] = signal
return signal

24
pyhcl/tester/tester.py Normal file
View File

@ -0,0 +1,24 @@
from pyhcl.tester.executer import TesterExecuter
from pyhcl.ir.low_ir import *
from pyhcl.dsl.emitter import Emitter
class Tester:
def __init__(self, m: Module):
circuit = Emitter.elaborate(m)
self.main = circuit.main
self.executer = TesterExecuter(circuit)
self.executer.init_executer()
def poke(self, name: str, value: int):
self.executer.set_value(self.main, name, value)
def peek(self, name: str) -> int:
res = self.executer.get_value(self.main, name)
return int(res, 2) if isinstance(res, str) else res
def expect(self, a, b) -> bool:
return a == b
def step(self, n):
self.executer.step(n, self.main)

131
pyhcl/tester/utils.py Normal file
View File

@ -0,0 +1,131 @@
from collections import OrderedDict, defaultdict
from copy import deepcopy
from pyhcl.tester.exception import TesterException
class DAG:
""" Directed acyclic graph implementation."""
def __init__(self):
self.graph = OrderedDict()
def add_node(self, name: str, graph = None):
if graph is None:
graph = self.graph
if name in graph:
...
else:
graph[name] = set()
def add_node_if_not_exists(self, name: str, graph = None):
try:
self.add_node(name, graph = graph)
except TesterException as e:
raise e
def delete_node(self, name: str, graph = None):
if graph is None:
graph = self.graph
if name not in graph:
raise TesterException(f'node {name} is not exists.')
graph.pop(name)
for _, edges in graph.items():
if name in edges:
edges.remove(name)
def delete_node_if_exists(self, name: str, graph = None):
try:
self.delete_node(name, graph = graph)
except TesterException as e:
raise e
def add_edge(self, ind_node, dep_node, graph = None):
if graph is None:
graph = self.graph
if ind_node not in graph or dep_node not in graph:
raise TesterException(f'nodes do not exist in graph.')
test_graph = deepcopy(graph)
test_graph[ind_node].add(dep_node)
is_valid, msg = self.validate(test_graph)
if is_valid:
graph[ind_node].add(dep_node)
else:
raise TesterException(f'Loop do exist in graph: {msg}')
def delete_edge(self, ind_node, dep_node, graph = None):
if graph is None:
graph = self.graph
if dep_node not in graph.get(ind_node, []):
raise TesterException(f'This edge does not exist in graph')
graph[ind_node].remove(dep_node)
def ind_nodes(self, graph = None):
if graph == None:
graph = self.graph
dep_nodes = set(
node for deps in graph.values() for node in deps
)
return [node for node in graph.keys() if node not in dep_nodes]
def topological_sort(self, graph = None):
if graph is None:
graph = self.graph
result = []
in_degree = defaultdict(lambda: 0)
for u in graph:
for v in graph[u]:
in_degree[v] += 1
ready = [node for node in graph if not in_degree[node]]
while ready:
u = ready.pop()
result.append(u)
for v in graph[u]:
in_degree[v] -= 1
if in_degree[v] == 0:
ready.append(v)
if len(result) == len(graph):
return result
else:
raise TesterException(f'graph is not acyclic.')
def validate(self, graph = None):
if graph is None:
graph = self.graph
if len(self.ind_nodes(graph)) == 0:
return False, 'no independent nodes detected.'
try:
self.topological_sort(graph)
except TesterException:
return False, 'graph is not acyclic.'
return True, 'valid'
def visit_graph(self, graph = None):
visited = []
if graph is None:
graph = self.graph
for v in graph:
for u in graph[v]:
visited.append(f'{v} -> {u}')
return visited
def travel_graph(self, init: list, graph = None):
if graph is None:
graph = self.graph
visits = []
history = set()
while len(init) > 0:
visited = init.pop(0)
if len(graph[visited]) > 0:
for u in graph[visited]:
if u not in init and u not in history:
init.append(u)
history.add(visited)
visits.append(visited)
return visits
def size(self):
return len(self.graph)

154
pyhcl/tester/wir.py Normal file
View File

@ -0,0 +1,154 @@
from pyclbr import Function
from pyhcl.ir.low_ir import *
@dataclass(frozen=True)
class WUIntLiteral(Expression):
expr: Expression
def get_value(self, *args) -> int:
return self.expr.value
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WSIntLiteral(Expression):
expr: Expression
def get_value(self, *args) -> int:
return self.expr.value
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WReference(Expression):
expr: Expression
get_func: Function
set_func: Function
def get_value(self, *args) -> int:
return self.get_func(*args)
def set_value(self, *args) -> int:
return self.set_func(*args)
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WSubField(Expression):
expr: Expression
get_func: Function
set_func: Function
def get_value(self, *args) -> int:
return self.get_func(*args)
def set_value(self, *args) -> int:
return self.set_func(*args)
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WSubIndex(Expression):
expr: Expression
get_func: Function
set_func: Function
def get_value(self, *args) -> int:
return self.get_func(*args)
def set_value(self, *args) -> int:
return self.set_func(*args)
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WSubAccess(Expression):
expr: Expression
get_func: Function
set_func: Function
def get_value(self, *args) -> int:
return self.get_func(*args)
def set_value(self, *args) -> int:
return self.set_func(*args)
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WMux(Expression):
expr: Expression
get_func: Function
def get_value(self, *args) -> int:
return self.get_func(*args)
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WValidIf(Expression):
expr: Expression
get_func: Function
def get_value(self, *args) -> int:
return self.get_func(*args)
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WDoPrim(Expression):
expr: Expression
get_func: Function
def get_value(self, *args) -> int:
return self.get_func(*args)
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...
@dataclass(frozen=True)
class WInt(Expression):
value: int
def get_value(self, *args) -> int:
return self.value
def serialize(self) -> str:
...
def verilog_serialize(self) -> str:
...