forked from opendacs/PyHCL
commit
6c20d181b8
|
@ -10,6 +10,7 @@ simulation/
|
|||
|
||||
# firrtl
|
||||
.fir/
|
||||
.v/
|
||||
|
||||
# obj
|
||||
obj_dir/*
|
||||
|
@ -122,10 +123,7 @@ venv.bak/
|
|||
|
||||
|
||||
temp
|
||||
|
||||
# jar
|
||||
*.jar
|
||||
main.py
|
||||
|
||||
|
||||
# docs
|
||||
/docs/_site/
|
||||
/docs/Gemfile.lock
|
||||
.jekyll-cache
|
100
README.md
100
README.md
|
@ -9,8 +9,11 @@ and dynamically typed objects.
|
|||
|
||||
The goal of PyHCL is providing a complete design and verification tool flow for heterogeneous computing systems flexibly using the same design methodology.
|
||||
|
||||
PyHCL is powered by [FIRRTL](https://github.com/freechipsproject/firrtl), an intermediate representation for digital circuit design. With the FIRRTL
|
||||
compiler framework, PyHCL-generated circuits can be compiled to the widely-used HDL Verilog.
|
||||
PyHCL is powered by [FIRRTL](https://github.com/freechipsproject/firrtl), an intermediate representation for digital circuit design.
|
||||
|
||||
PyHCL-generated circuits can be compiled to the widely-used HDL Verilog.
|
||||
|
||||
Attention: The back end of the compilation is highly experimental.
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
@ -36,11 +39,11 @@ class FullAdder(Module):
|
|||
io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin
|
||||
```
|
||||
|
||||
#### Compiling To FIRRTL
|
||||
#### Compiling To High FIRRTL
|
||||
|
||||
Compiling module by calling `compile_to_firrtl`:
|
||||
Compiling module by calling `compile_to_highform`:
|
||||
```python
|
||||
Emitter.dump(Emitter.emit(FullAdder()), "FullAdder.fir")
|
||||
Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
|
||||
```
|
||||
|
||||
Will generate the following FIRRTL codes:
|
||||
|
@ -49,55 +52,74 @@ circuit FullAdder :
|
|||
module FullAdder :
|
||||
input clock : Clock
|
||||
input reset : UInt<1>
|
||||
input FullAdder_io_a : UInt<1>
|
||||
input FullAdder_io_b : UInt<1>
|
||||
input FullAdder_io_cin : UInt<1>
|
||||
output FullAdder_io_sum : UInt<1>
|
||||
output FullAdder_io_cout : UInt<1>
|
||||
output io : {flip a : UInt<1>, flip b : UInt<1>, flip cin : UInt<1>, s : UInt<1>, cout : UInt<1>}
|
||||
|
||||
node _T_0 = xor(FullAdder_io_a, FullAdder_io_b)
|
||||
node _T_1 = xor(_T_0, FullAdder_io_cin)
|
||||
FullAdder_io_sum <= _T_1
|
||||
node _T_2 = and(FullAdder_io_a, FullAdder_io_b)
|
||||
node _T_3 = and(FullAdder_io_b, FullAdder_io_cin)
|
||||
node _T = xor(io.a, io.b)
|
||||
node _T_1 = xor(_T, io.cin)
|
||||
io.s <= _T_1
|
||||
node _T_2 = and(io.a, io.b)
|
||||
node _T_3 = and(io.a, io.cin)
|
||||
node _T_4 = or(_T_2, _T_3)
|
||||
node _T_5 = and(FullAdder_io_a, FullAdder_io_cin)
|
||||
node _T_5 = and(io.b, io.cin)
|
||||
node _T_6 = or(_T_4, _T_5)
|
||||
FullAdder_io_cout <= _T_6
|
||||
io.cout <= _T_6
|
||||
```
|
||||
|
||||
#### Compiling To Lowered FIRRTL
|
||||
|
||||
Compiling module by calling `compile_to_lowform`:
|
||||
|
||||
```python
|
||||
Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
|
||||
```
|
||||
|
||||
Will generate the following FIRRTL codes:
|
||||
|
||||
```
|
||||
circuit FullAdder :
|
||||
module FullAdder :
|
||||
input clock : Clock
|
||||
input reset : UInt<1>
|
||||
input io_a : UInt<1>
|
||||
input io_b : UInt<1>
|
||||
input io_cin : UInt<1>
|
||||
output io_s : UInt<1>
|
||||
output io_cout : UInt<1>
|
||||
|
||||
node _T = xor(io_a, io_b)
|
||||
node _T_1 = xor(_T, io_cin)
|
||||
io_s <= _T_1
|
||||
node _T_2 = and(io_a, io_b)
|
||||
node _T_3 = and(io_a, io_cin)
|
||||
node _T_4 = or(_T_2, _T_3)
|
||||
node _T_5 = and(io_b, io_cin)
|
||||
node _T_6 = or(_T_4, _T_5)
|
||||
io_cout <= _T_6
|
||||
```
|
||||
|
||||
#### Compiling To Verilog
|
||||
|
||||
While FIRRTL is generated, PyHCL's job is complete. To further compile to Verilog, the [FIRRTL compiler framework](
|
||||
https://github.com/freechipsproject/firrtl) is required:
|
||||
Compiling module by calling `compile_to_verilog`:
|
||||
|
||||
```shell script
|
||||
Emitter.dumpVerilog(Emitter.dump(Emitter.emit(FullAdder()), "FullAdder.fir"))
|
||||
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")
|
||||
```
|
||||
|
||||
Then `FullAdder.v` will be generated:
|
||||
```verilog
|
||||
module FullAdder(
|
||||
input clock,
|
||||
input reset,
|
||||
input FullAdder_io_a,
|
||||
input FullAdder_io_b,
|
||||
input FullAdder_io_cin,
|
||||
output FullAdder_io_sum,
|
||||
output FullAdder_io_cout
|
||||
input clock,
|
||||
input reset,
|
||||
input io_a,
|
||||
input io_b,
|
||||
input io_cin,
|
||||
output io_s,
|
||||
output io_cout
|
||||
);
|
||||
wire _T_0;
|
||||
wire _T_2;
|
||||
wire _T_3;
|
||||
wire _T_4;
|
||||
wire _T_5;
|
||||
assign _T_0 = FullAdder_io_a ^ FullAdder_io_b;
|
||||
assign _T_2 = FullAdder_io_a & FullAdder_io_b;
|
||||
assign _T_3 = FullAdder_io_b & FullAdder_io_cin;
|
||||
assign _T_4 = _T_2 | _T_3;
|
||||
assign _T_5 = FullAdder_io_a & FullAdder_io_cin;
|
||||
assign FullAdder_io_sum = _T_0 ^ FullAdder_io_cin;
|
||||
assign FullAdder_io_cout = _T_4 | _T_5;
|
||||
|
||||
assign io_s = io_a ^ io_b ^ io_cin;
|
||||
assign io_cout = io_a & io_b | io_a & io_cin | io_b & io_cin;
|
||||
|
||||
endmodule
|
||||
```
|
||||
|
||||
|
|
|
@ -15,4 +15,4 @@ class FullAdder(Module):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Emitter.dumpVerilog(Emitter.dump(Emitter.emit(FullAdder()), "FullAdder.fir"))
|
||||
Emitter.dump(Emitter.emit(FullAdder(), True), "FullAdder.v")
|
||||
|
|
73
main.py
73
main.py
|
@ -1,61 +1,24 @@
|
|||
from pyhcl import *
|
||||
|
||||
W = 8 # 位宽
|
||||
class FullAdder(Module):
|
||||
io = IO(
|
||||
a=Input(Bool),
|
||||
b=Input(Bool),
|
||||
cin=Input(Bool),
|
||||
sum=Output(Bool),
|
||||
cout=Output(Bool),
|
||||
)
|
||||
|
||||
# Generate the sum
|
||||
io.sum <<= io.a ^ io.b ^ io.cin
|
||||
|
||||
def matrixMul(x: int, y: int, z: int):
|
||||
"""
|
||||
x*y × y*z 矩阵乘法电路
|
||||
"""
|
||||
|
||||
class MatrixMul(Module):
|
||||
io = IO(
|
||||
a=Input(Vec(x, Vec(y, U.w(W)))),
|
||||
b=Input(Vec(y, Vec(z, U.w(W)))),
|
||||
o=Output(Vec(x, Vec(z, U.w(W)))),
|
||||
)
|
||||
|
||||
for i, a_row in enumerate(io.a):
|
||||
for j, b_col in enumerate(zip(*io.b)):
|
||||
io.o[i][j] <<= Sum(a * b for a, b in zip(a_row, b_col))
|
||||
|
||||
return MatrixMul()
|
||||
|
||||
|
||||
def bias(n):
|
||||
return U.w(W)(n)
|
||||
|
||||
|
||||
def weight(lst):
|
||||
return VecInit(U.w(W)(i) for i in lst)
|
||||
|
||||
|
||||
def neurons(w, b):
|
||||
"""
|
||||
参数:权重向量 w,偏移量 b
|
||||
输出:神经网络神经元电路 *注:暂无通过非线性传递函数
|
||||
"""
|
||||
|
||||
class Unit(Module):
|
||||
io = IO(
|
||||
i=Input(Vec(len(w), U.w(W))),
|
||||
o=Output(U.w(W))
|
||||
)
|
||||
m = matrixMul(1, len(w), 1).io
|
||||
m.a <<= io.i
|
||||
|
||||
m.b <<= w
|
||||
io.o <<= m.o[0][0] + b
|
||||
|
||||
return Unit()
|
||||
|
||||
|
||||
def main():
|
||||
# 得到权重向量为[3, 4, 5, 6, 7, 8, 9, 10],偏移量为14的神经元电路
|
||||
n = neurons(weight([3, 4, 5, 6, 7, 8, 9, 10]), bias(14))
|
||||
f = Emitter.dump(Emitter.emit(n), "neurons.fir")
|
||||
#Emitter.dumpVerilog(f)
|
||||
|
||||
# Generate the carry
|
||||
io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
# emit high firrtl
|
||||
Emitter.dump(Emitter.emit(FullAdder(), HighForm), "FullAdder.fir")
|
||||
# emit lowered firrtl
|
||||
Emitter.dump(Emitter.emit(FullAdder(), LowForm), "FullAdder.lo.fir")
|
||||
# emit verilog
|
||||
Emitter.dump(Emitter.emit(FullAdder(), Verilog), "FullAdder.v")
|
||||
|
|
|
@ -10,3 +10,4 @@ from .funcs import CatVecL2H, CatVecH2L, CatBits, OneDimensionalization, Sum, De
|
|||
from .memory import Mem
|
||||
from .clockdomin import clockdomin
|
||||
from .verifaction import doAssert, doAssume, doCover
|
||||
from .stage import Form, HighForm, MidForm, LowForm, Verilog
|
||||
|
|
|
@ -7,20 +7,10 @@ from pyhcl.core._emit_context import EmitterContext
|
|||
from pyhcl.dsl.module import Module
|
||||
from pyhcl.ir import low_ir
|
||||
from pyhcl.util.firrtltools import replacewithfirmod
|
||||
from pyhcl.dsl.stage import Form, HighForm
|
||||
|
||||
|
||||
class Emitter:
|
||||
# 传入模块对象,返回str---firrtl代码
|
||||
@staticmethod
|
||||
def emit(m: Module, toverilog=False) -> str:
|
||||
circuit = Emitter.elaborate(m)
|
||||
# 将Circuit对象转化为str
|
||||
if(toverilog):
|
||||
return circuit.verilog_serialize()
|
||||
else:
|
||||
return circuit.serialize() # firrtl代码
|
||||
|
||||
# 传入模块对象,返回Circuit对象
|
||||
@staticmethod
|
||||
def elaborate(m: Module) -> low_ir.Circuit:
|
||||
ec: EmitterContext = EmitterContext(m, {}, Counter())
|
||||
|
@ -30,19 +20,39 @@ class Emitter:
|
|||
DynamicContext.clearScope()
|
||||
return circuit
|
||||
|
||||
# 传入firrtl代码和文件名,将firrtl代码写入文件中,并返回文件路径
|
||||
@staticmethod
|
||||
def emit(m: Module, f: Form = HighForm) -> str:
|
||||
return f(Emitter.elaborate(m)).emit()
|
||||
|
||||
@staticmethod
|
||||
def dump(s, filename) -> str:
|
||||
if not os.path.exists('.fir'):
|
||||
os.mkdir('.fir')
|
||||
dir_name = "." + filename.split(".")[-1]
|
||||
if not os.path.exists(dir_name):
|
||||
os.mkdir(dir_name)
|
||||
|
||||
f = os.path.join('.fir', filename)
|
||||
f = os.path.join(dir_name, filename)
|
||||
with open(f, "w+") as fir_file:
|
||||
fir_file.write(s)
|
||||
|
||||
return f
|
||||
|
||||
# 传入firrtl文件路径,执行firrtl命令,将firrtl代码编译为verilog代码
|
||||
@staticmethod
|
||||
def dumpVerilog(filename):
|
||||
os.system('firrtl -i %s -o %s -X verilog' % (filename, filename))
|
||||
def dumpVerilog(filename, use_jar=False):
|
||||
if use_jar:
|
||||
os.system('java -jar firrtl.jar -i %s -o %s -X verilog' % (filename, filename))
|
||||
else:
|
||||
os.system('firrtl -i %s -o %s -X verilog' % (filename, filename))
|
||||
|
||||
@staticmethod
|
||||
def dumpMidForm(filename, use_jar=False):
|
||||
if use_jar:
|
||||
os.system('java -jar firrtl.jar -i %s -o %s -X middle' % (filename, filename))
|
||||
else:
|
||||
os.system('firrtl -i %s -o %s -X middle' % (filename, filename))
|
||||
|
||||
@staticmethod
|
||||
def dumpLoweredForm(filename, use_jar):
|
||||
if use_jar:
|
||||
os.system('java -jar firrtl.jar -i %s -o %s -X low' % (filename, filename))
|
||||
else:
|
||||
os.system('firrtl -i %s -o %s -X low' % (filename, filename))
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
from abc import ABC, abstractclassmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pyhcl.ir import low_ir
|
||||
from pyhcl.passes.check_form import CheckHighForm
|
||||
from pyhcl.passes.check_types import CheckTypes
|
||||
from pyhcl.passes.check_flows import CheckFlow
|
||||
from pyhcl.passes.check_widths import CheckWidths
|
||||
from pyhcl.passes.auto_inferring import AutoInferring
|
||||
from pyhcl.passes.replace_subaccess import ReplaceSubaccess
|
||||
from pyhcl.passes.expand_aggregate import ExpandAggregate
|
||||
from pyhcl.passes.expand_whens import ExpandWhens
|
||||
from pyhcl.passes.expand_memory import ExpandMemory
|
||||
from pyhcl.passes.optimize import Optimize
|
||||
from pyhcl.passes.verilog_optimize import VerilogOptimize
|
||||
from pyhcl.passes.remove_access import RemoveAccess
|
||||
from pyhcl.passes.expand_sequential import ExpandSequential
|
||||
from pyhcl.passes.handle_instance import HandleInstance
|
||||
from pyhcl.passes.utils import AutoName
|
||||
|
||||
class Form(ABC):
|
||||
@abstractclassmethod
|
||||
def emit(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass
|
||||
class HighForm(Form):
|
||||
c: low_ir.Circuit
|
||||
|
||||
def emit(self) -> str:
|
||||
self.c = CheckHighForm(self.c).run()
|
||||
self.c = AutoInferring().run(self.c)
|
||||
self.c = CheckTypes().run(self.c)
|
||||
self.c = CheckFlow().run(self.c)
|
||||
self.c = CheckWidths().run(self.c)
|
||||
return self.c.serialize()
|
||||
|
||||
@dataclass
|
||||
class MidForm(Form):
|
||||
def emit(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass
|
||||
class LowForm(Form):
|
||||
c: low_ir.Circuit
|
||||
|
||||
def emit(self) -> str:
|
||||
AutoName()
|
||||
self.c = CheckHighForm(self.c).run()
|
||||
self.c = AutoInferring().run(self.c)
|
||||
self.c = CheckTypes().run(self.c)
|
||||
self.c = CheckFlow().run(self.c)
|
||||
self.c = CheckWidths().run(self.c)
|
||||
self.c = ExpandMemory().run(self.c)
|
||||
self.c = ReplaceSubaccess().run(self.c)
|
||||
self.c = ExpandAggregate().run(self.c)
|
||||
self.c = RemoveAccess().run(self.c)
|
||||
self.c = ExpandWhens().run(self.c)
|
||||
self.c = HandleInstance().run(self.c)
|
||||
self.c = Optimize().run(self.c)
|
||||
return self.c.serialize()
|
||||
|
||||
@dataclass
|
||||
class Verilog(Form):
|
||||
c: low_ir.Circuit
|
||||
|
||||
def emit(self) -> str:
|
||||
AutoName()
|
||||
self.c = CheckHighForm(self.c).run()
|
||||
self.c = AutoInferring().run(self.c)
|
||||
self.c = CheckTypes().run(self.c)
|
||||
self.c = CheckFlow().run(self.c)
|
||||
self.c = CheckWidths().run(self.c)
|
||||
self.c = ExpandAggregate().run(self.c)
|
||||
self.c = ReplaceSubaccess().run(self.c)
|
||||
self.c = RemoveAccess().run(self.c)
|
||||
self.c = VerilogOptimize().run(self.c)
|
||||
self.c = ExpandSequential().run(self.c)
|
||||
self.c = HandleInstance().run(self.c)
|
||||
self.c = Optimize().run(self.c)
|
||||
return self.c.verilog_serialize()
|
|
@ -84,7 +84,7 @@ class VecInit(Node):
|
|||
|
||||
# Connect Elements
|
||||
for i, node in enumerate(self.lst):
|
||||
for idx, elem in self.subIdxs(low_ir.SubIndex(ref, i, typ.typ), node, ctx):
|
||||
for idx, elem in self.subIdxs(low_ir.SubIndex('', ref, i, typ.typ), node, ctx):
|
||||
con = low_ir.Connect(idx, elem)
|
||||
ctx.appendFinalStatement(con, self.scopeId)
|
||||
|
||||
|
@ -93,7 +93,7 @@ class VecInit(Node):
|
|||
def subIdxs(self, idx, node, ctx):
|
||||
if isinstance(node, VecInit):
|
||||
return [(nIdx, elem) for i, n in enumerate(node.lst)
|
||||
for nIdx, elem in node.subIdxs(low_ir.SubIndex(idx, i, node.typ.mapToIR(ctx)), n, ctx)]
|
||||
for nIdx, elem in node.subIdxs(low_ir.SubIndex('', idx, i, node.typ.mapToIR(ctx)), n, ctx)]
|
||||
else:
|
||||
return [(idx, node.mapToIR(ctx))]
|
||||
|
||||
|
|
|
@ -2,13 +2,11 @@ from __future__ import annotations
|
|||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pyhcl.ir.low_node import FirrtlNode
|
||||
from pyhcl.ir.low_prim import PrimOp, Bits
|
||||
from pyhcl.ir.utils import backspace, indent, deleblankline, backspace1
|
||||
|
||||
|
||||
from pyhcl.ir.low_prim import PrimOp, Bits, Cat
|
||||
from pyhcl.ir.utils import backspace, indent, deleblankline, backspace1, get_binary_width, TransformException, DAG
|
||||
class Info(FirrtlNode, ABC):
|
||||
"""INFOs"""
|
||||
def serialize(self) -> str:
|
||||
|
@ -58,7 +56,11 @@ class Expression(FirrtlNode, ABC):
|
|||
|
||||
class Type(FirrtlNode, ABC):
|
||||
"""TYPEs"""
|
||||
...
|
||||
def map_type(self, typ):
|
||||
...
|
||||
|
||||
def map_width(self, typ):
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, init=False)
|
||||
|
@ -69,7 +71,6 @@ class UnknownType(Type):
|
|||
def verilog_serialize(self) -> str:
|
||||
return self.serialize()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Reference(Expression):
|
||||
name: str
|
||||
|
@ -92,11 +93,12 @@ class SubField(Expression):
|
|||
return f"{self.expr.serialize()}.{self.name}"
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f"{self.expr.verilog_serialize()}_{self.name}"
|
||||
return self.serialize()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SubIndex(Expression):
|
||||
name: str
|
||||
expr: Expression
|
||||
value: int
|
||||
typ: Type
|
||||
|
@ -105,7 +107,7 @@ class SubIndex(Expression):
|
|||
return f"{self.expr.serialize()}[{self.value}]"
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f"{self.expr.verilog_serialize()}[{self.value}]"
|
||||
return self.serialize()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -117,9 +119,8 @@ class SubAccess(Expression):
|
|||
def serialize(self) -> str:
|
||||
return f"{self.expr.serialize()}[{self.index.serialize()}]"
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f"{self.expr.verilog_serialize()}[{self.index.verilog_serialize()}]"
|
||||
|
||||
def verilog_serialize(self):
|
||||
return self.serialize()
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Mux(Expression):
|
||||
|
@ -134,6 +135,18 @@ class Mux(Expression):
|
|||
def verilog_serialize(self) -> str:
|
||||
return f"{self.cond.verilog_serialize()} ? {self.tval.verilog_serialize()} : {self.fval.verilog_serialize()}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ValidIf(Expression):
|
||||
cond: Expression
|
||||
value: Expression
|
||||
typ: Type
|
||||
|
||||
def serialize(self) -> str:
|
||||
return f"validif({self.cond.serialize()}, {self.value.serialize()})"
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f"{self.cond.verilog_serialize()} ? {self.value.verilog_serialize()} : Z"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DoPrim(Expression):
|
||||
|
@ -148,7 +161,18 @@ class DoPrim(Expression):
|
|||
|
||||
def verilog_serialize(self) -> str:
|
||||
sl: List[str] = [arg.verilog_serialize() for arg in self.args] + [repr(con) for con in self.consts]
|
||||
return f'{self.op.verilog_serialize().join(sl)}'
|
||||
if isinstance(self.op, Cat):
|
||||
return f'{{{", ".join(sl)}}}'
|
||||
elif isinstance(self.op, Bits):
|
||||
arg = self.args[0]
|
||||
msb, lsb = self.consts[0], self.consts[1]
|
||||
msb_lsb = f'{msb}: {lsb}' if msb != lsb else f'{msb}'
|
||||
return f'{arg.verilog_serialize()}[{msb_lsb}]'
|
||||
|
||||
if len(sl) > 1:
|
||||
return f'{self.op.verilog_serialize().join(sl)}'
|
||||
else:
|
||||
return f'{self.op.verilog_serialize()}{sl.pop()}'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -175,10 +199,10 @@ class UnknownWidth(Width):
|
|||
width: int = None
|
||||
|
||||
def serialize(self) -> str:
|
||||
return ""
|
||||
return ''
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return ""
|
||||
return ''
|
||||
|
||||
|
||||
@dataclass(frozen=True, init=False)
|
||||
|
@ -291,8 +315,7 @@ class Flip(Orientation):
|
|||
return 'flip '
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
# no "flip" in verilog, bug here
|
||||
return 'flip'
|
||||
return ''
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -306,7 +329,7 @@ class Field(FirrtlNode):
|
|||
return f'{self.flip.serialize()}{self.name} : {self.typ.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f'{self.flip.verilog_serialize()}{self.typ.verilog_serialize()}\t${self.name}'
|
||||
return ''
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -317,14 +340,7 @@ class BundleType(AggregateType):
|
|||
return '{' + ', '.join([f.serialize() for f in self.fields]) + '}'
|
||||
|
||||
def verilog_serialize(self) -> list:
|
||||
field_list = []
|
||||
for f in self.fields:
|
||||
if type(f.typ.verilog_serialize()) is list:
|
||||
for t in f.typ.verilog_serialize():
|
||||
field_list.append(f'${f.name}{t}')
|
||||
else:
|
||||
field_list.append(f.verilog_serialize())
|
||||
return field_list
|
||||
return ''
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -335,12 +351,13 @@ class VectorType(AggregateType):
|
|||
def serialize(self) -> str:
|
||||
return f'{self.typ.serialize()}[{self.size}]'
|
||||
|
||||
def verilog_serialize(self) -> list:
|
||||
return [f'_{v}' for v in range(self.size)]
|
||||
|
||||
def verilog_serialize(self):
|
||||
return ''
|
||||
|
||||
def irWithIndex(self, index):
|
||||
if isinstance(index, int):
|
||||
return lambda _: {"ir": SubIndex(_, index, self.typ)}
|
||||
return lambda _: {"ir": SubIndex('', _, index, self.typ)}
|
||||
else:
|
||||
return lambda _: {"ir": SubAccess(_, index, self.typ)}
|
||||
|
||||
|
@ -354,7 +371,7 @@ class MemoryType(AggregateType):
|
|||
return f'{self.typ.serialize()}[{self.size}]'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return [f'_{v}' for v in range(self.size)]
|
||||
return f'reg {self.typ.verilog_serialize()} m [0:{self.size-1}];'
|
||||
|
||||
def irWithIndex(self, index):
|
||||
return lambda _: {"ir": lambda name, mem, clk, rw: DefMemPort(name, mem, index, clk, rw), "inPort": True}
|
||||
|
@ -368,7 +385,7 @@ class ClockType(GroundType):
|
|||
return 'Clock'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return ""
|
||||
return ''
|
||||
|
||||
|
||||
@dataclass(frozen=True, init=False)
|
||||
|
@ -379,8 +396,7 @@ class ResetType(GroundType):
|
|||
return "UInt<1>"
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
# Todo
|
||||
return "Reset"
|
||||
return ''
|
||||
|
||||
|
||||
@dataclass(frozen=True, init=False)
|
||||
|
@ -391,8 +407,7 @@ class AsyncResetType(GroundType):
|
|||
return "AsyncReset"
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
# Todo
|
||||
return "AsyncReset"
|
||||
return ''
|
||||
|
||||
|
||||
class Direction(FirrtlNode, ABC):
|
||||
|
@ -434,20 +449,7 @@ class Port(FirrtlNode):
|
|||
return f'{self.direction.serialize()} {self.name} : {self.typ.serialize()}{self.info.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
if type(self.typ.verilog_serialize()) is str:
|
||||
return f'{self.direction.verilog_serialize()}\t{self.typ.verilog_serialize()}\t{self.name},\n'
|
||||
else:
|
||||
portdeclares = ''
|
||||
seq = self.typ.verilog_serialize()
|
||||
for s in seq:
|
||||
ns = s.replace('$', f'{self.name}_')
|
||||
if "flip" in ns:
|
||||
ns = ns.replace('flip', "")
|
||||
portdeclares += f'{self.direction.verilog_serialize(True)}\t{ns},\n'
|
||||
else:
|
||||
portdeclares += f'{self.direction.verilog_serialize()}\t{ns},\n'
|
||||
return portdeclares
|
||||
|
||||
return f'{self.direction.verilog_serialize()}\t{self.typ.verilog_serialize()}\t{self.name},\t{self.info.verilog_serialize()}'
|
||||
|
||||
class Statement(FirrtlNode, ABC):
|
||||
...
|
||||
|
@ -459,7 +461,7 @@ class EmptyStmt(Statement):
|
|||
return 'skip'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return '// skip'
|
||||
return ''
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -472,7 +474,7 @@ class DefWire(Statement):
|
|||
return f'wire {self.name} : {self.typ.serialize()}{self.info.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f'wire\t{self.typ.verilog_serialize()}\t{self.name}{self.info.verilog_serialize()};'
|
||||
return f'wire\t{self.typ.verilog_serialize()}\t{self.name};\t{self.info.verilog_serialize()}'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -485,15 +487,16 @@ class DefRegister(Statement):
|
|||
info: Info = NoInfo()
|
||||
|
||||
def serialize(self) -> str:
|
||||
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
|
||||
if self.init is not None else ""
|
||||
i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
|
||||
if self.reset is not None else ""
|
||||
if self.init:
|
||||
i: str = indent(f' with : \nreset => ({self.reset.serialize()}, {self.init.serialize()})') \
|
||||
if self.init is not None else ""
|
||||
else:
|
||||
i: str = indent(f' with : \nreset => ({self.reset.serialize()})') \
|
||||
if self.reset is not None else ""
|
||||
return f'reg {self.name} : {self.typ.serialize()}, {self.clock.serialize()}{i}{self.info.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f'reg {self.typ.verilog_serialize()}\t{self.name}{self.info.verilog_serialize()};'
|
||||
|
||||
return f'reg\t{self.typ.verilog_serialize()}\t{self.name};\t{self.info.verilog_serialize()}'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DefInstance(Statement):
|
||||
|
@ -506,36 +509,40 @@ class DefInstance(Statement):
|
|||
return f'inst {self.name} of {self.module}{self.info.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
instdeclares = ''
|
||||
instdeclares: List[str] = []
|
||||
portdeclares: List[str] = []
|
||||
for p in self.ports:
|
||||
instdeclares += f'\n.{p.name}({self.name}_{p.name}),'
|
||||
return f'{self.module} {self.name} ({instdeclares}\n);'
|
||||
portdeclares.append(f'wire\t{p.typ.verilog_serialize()}\t{self.name}_{p.name};')
|
||||
instdeclares.append(indent(f'\n.{p.name}({self.name}_{p.name}),'))
|
||||
port_decs = '\n'.join(portdeclares)
|
||||
return f"{port_decs}\n{self.module}\t{self.name}(\t{self.info.verilog_serialize()}{''.join(instdeclares)});"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WDefMemory(Statement):
|
||||
name: str
|
||||
memType: MemoryType
|
||||
dataType: Type
|
||||
depth: int
|
||||
writeLatency: int
|
||||
readLatency: int
|
||||
readers: List[str]
|
||||
writers: List[str]
|
||||
readUnderWrite: Optional[str] = None
|
||||
info: Info = NoInfo()
|
||||
|
||||
# @dataclass(frozen=True)
|
||||
# class DefMemory(Statement):
|
||||
# name: str
|
||||
# dataType: Type
|
||||
# depth: int
|
||||
# writeLatency: int
|
||||
# readLatency: int
|
||||
# readers: List[str]
|
||||
# writers: List[str]
|
||||
# readWriters: List[str]
|
||||
# readUnderWrite: Optional[str] = None
|
||||
# info: Info = NoInfo()
|
||||
#
|
||||
# def serialize(self) -> str:
|
||||
# lst = [f"\ndata-type => {self.dataType.serialize()}",
|
||||
# f"depth => {self.depth}",
|
||||
# f"read-latency => {self.readLatency}",
|
||||
# f"write-latency => {self.writeLatency}"] + \
|
||||
# [f"reader => {r}" for r in self.readers] + \
|
||||
# [f"writer => {w}" for w in self.writers] + \
|
||||
# [f"readwriter => {rw}" for rw in self.readWriters] + \
|
||||
# ['read-under-write => undefined']
|
||||
# s = indent('\n'.join(lst))
|
||||
# return f'mem {self.name} :{self.info.serialize()}{s}'
|
||||
def serialize(self) -> str:
|
||||
lst = [f"\ndata-type => {self.dataType.serialize()}",
|
||||
f"depth => {self.depth}",
|
||||
f"read-latency => {self.readLatency}",
|
||||
f"write-latency => {self.writeLatency}"] + \
|
||||
[f"reader => {r}" for r in self.readers] + \
|
||||
[f"writer => {w}" for w in self.writers] + \
|
||||
['read-under-write => undefined']
|
||||
s = indent('\n'.join(lst))
|
||||
return f'mem {self.name} :{self.info.serialize()}{s}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return self.memType.verilog_serialize()
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DefMemory(Statement):
|
||||
|
@ -547,10 +554,7 @@ class DefMemory(Statement):
|
|||
return f'cmem {self.name} : {self.memType.serialize()}{self.info.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
memorydeclares = ''
|
||||
memorydeclares += f'{self.memType.verilog_serialize()}\n'
|
||||
memorydeclares += f'{self.info.verilog_serialize()}\n'
|
||||
return memorydeclares
|
||||
return f'{self.memType.verilog_serialize()}\t{self.info.verilog_serialize()}'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -563,7 +567,7 @@ class DefNode(Statement):
|
|||
return f'node {self.name} = {self.value.serialize()}{self.info.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f'wire {self.value.typ.verilog_serialize()} {self.name} = {self.value.verilog_serialize()}{self.info.verilog_serialize()};'
|
||||
return f'wire\t{self.value.typ.verilog_serialize()}\t{self.name} = {self.value.verilog_serialize()};\t{self.info.verilog_serialize()}'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -582,12 +586,15 @@ class DefMemPort(Statement):
|
|||
|
||||
def verilog_serialize(self) -> str:
|
||||
memportdeclares = ''
|
||||
memportdeclares += indent(f'{self.mem.verilog_serialize()}')
|
||||
# TODO
|
||||
memportdeclares += f'wire {self.mem.typ.typ.verilog_serialize()} {self.mem.verilog_serialize()}_{self.name}_data;\n'
|
||||
memportdeclares += f'wire [{get_binary_width(self.mem.typ.size)-1}:0] {self.mem.verilog_serialize()}_{self.name}_addr;\n'
|
||||
memportdeclares += f'wire {self.mem.verilog_serialize()}_{self.name}_en;\n'
|
||||
memportdeclares += f'assign {self.mem.verilog_serialize()}_{self.name}_addr = {get_binary_width(self.mem.typ.size)}\'h{self.index.value};\n'
|
||||
memportdeclares += f'assign {self.mem.verilog_serialize()}_{self.name}_en = 1\'h1;\n'
|
||||
if self.rw is False:
|
||||
memportdeclares += f'wire {self.mem.verilog_serialize()}_{self.name}_mask;\n'
|
||||
return memportdeclares
|
||||
|
||||
|
||||
# stmt pass will change the conseq and alt
|
||||
@dataclass(frozen=False)
|
||||
class Conditionally(Statement):
|
||||
pred: Expression
|
||||
|
@ -602,132 +609,40 @@ class Conditionally(Statement):
|
|||
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
s = indent(f'\n{self.conseq.verilog_serialize()}') + \
|
||||
('' if self.alt == EmptyStmt() else '\nelse' + indent(f'\n{self.alt.verilog_serialize()}'))
|
||||
return f'if ({self.pred.verilog_serialize()}) {self.info.verilog_serialize()}{s}'
|
||||
|
||||
class ValTracer:
|
||||
...
|
||||
|
||||
|
||||
class RegTracer(ValTracer):
|
||||
def __init__(self, reg: DefRegister):
|
||||
self.stmt = reg
|
||||
self.clock = reg.clock
|
||||
self.reset = reg.reset
|
||||
self.conds = {}
|
||||
|
||||
def add_cond(self, signal: str, action):
|
||||
if signal not in self.conds:
|
||||
self.conds[signal] = [action]
|
||||
else:
|
||||
self.conds[signal].append(action)
|
||||
|
||||
def gen_body(self):
|
||||
stmts = []
|
||||
final_map = {}
|
||||
for k, v in self.conds.items():
|
||||
if(k == "default"):
|
||||
stmts.append(Block(self.conds[k]))
|
||||
else:
|
||||
stmts.append(Conditionally(Reference(k, UIntType(IntWidth(1))), Block(self.conds[k]), EmptyStmt()))
|
||||
self.body = Block(stmts)
|
||||
|
||||
def gen_always_block(self):
|
||||
self.gen_body()
|
||||
res = f'always @(posedge {self.clock.verilog_serialize()}) begin\n' + \
|
||||
deleblankline(indent(f'\n{self.body.verilog_serialize()}')) + \
|
||||
f'\nend'
|
||||
return res
|
||||
|
||||
|
||||
class PassManager:
|
||||
def __init__(self, target_block):
|
||||
self.block = target_block
|
||||
|
||||
self.reg_map = {}
|
||||
|
||||
self.define_pass()
|
||||
|
||||
def renew(self):
|
||||
return self.stmts_pass(self.block)
|
||||
|
||||
# find out all reg definition and its clock and reset infos
|
||||
def define_pass(self):
|
||||
reg_tracers = [RegTracer(stmt) for stmt in self.block.stmts if type(stmt) == DefRegister]
|
||||
self.reg_map = {x.stmt.name: x for x in reg_tracers}
|
||||
|
||||
# nest
|
||||
def stmts_pass(self, block, signal: str = "default") -> Block:
|
||||
if type(block) == EmptyStmt:
|
||||
return EmptyStmt()
|
||||
for stmt in block.stmts:
|
||||
if type(stmt) == Connect and stmt.loc.name in self.reg_map:
|
||||
stmt.blocking = False
|
||||
self.reg_map[stmt.loc.name].add_cond(signal, stmt)
|
||||
elif type(stmt) == Conditionally:
|
||||
stmt.conseq = self.stmts_pass(stmt.conseq, stmt.pred.verilog_serialize())
|
||||
stmt.alt = self.stmts_pass(stmt.alt, "!" + stmt.pred.verilog_serialize())
|
||||
else:
|
||||
pass
|
||||
stmts = [stmt for stmt in block.stmts if self.pass_check(stmt)]
|
||||
return Block(stmts) if stmts else EmptyStmt()
|
||||
|
||||
def pass_check(self, stmt):
|
||||
return type(stmt) != Conditionally and type(stmt) != Connect \
|
||||
or type(stmt) == Connect and stmt.loc.name not in self.reg_map \
|
||||
or type(stmt) == Conditionally and type(stmt.conseq) != EmptyStmt
|
||||
|
||||
def gen_all_always_block(self):
|
||||
res = ""
|
||||
for reg, tracer in self.reg_map.items():
|
||||
res += f'\n// handle register {reg}'
|
||||
res += f'\n{tracer.gen_always_block()}'
|
||||
return res
|
||||
|
||||
s = indent(f"\n{self.conseq.verilog_serialize()}") + "\nend" + \
|
||||
("" if self.alt == EmptyStmt() else "\nelse begin" + indent(f"\n{self.alt.verilog_serialize()}") + "\nend")
|
||||
return f"if ({self.pred.verilog_serialize()}) begin\t{self.info.verilog_serialize()}{s}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Block(Statement):
|
||||
stmts: List[Statement]
|
||||
|
||||
"""
|
||||
def serialize(self) -> str:
|
||||
return '\n'.join([stmt.serialize() for stmt in self.stmts]) if self.stmts else ""
|
||||
"""
|
||||
|
||||
def auto_gen_node(self, stmt):
|
||||
return isinstance(stmt, DefNode) and stmt.name.startswith("_T")
|
||||
|
||||
# use less nodes
|
||||
def serialize(self) -> str:
|
||||
if not self.stmts:
|
||||
return ""
|
||||
node_exp_map = {stmt.name: stmt for stmt in self.stmts if self.auto_gen_node(stmt)}
|
||||
|
||||
# replace all reference in node_exp_map
|
||||
for k, v in node_exp_map.items():
|
||||
if isinstance(v.value, DoPrim):
|
||||
args = v.value.args
|
||||
cnt = 0
|
||||
for arg in args:
|
||||
if isinstance(arg, Reference) and arg.name in node_exp_map:
|
||||
node_exp_map[k].value.args[cnt] = node_exp_map[arg.name].value
|
||||
cnt += 1
|
||||
# replace all reference in connect
|
||||
for stmt in self.stmts:
|
||||
if isinstance(stmt, Connect) and isinstance(stmt.expr, Reference) and stmt.expr.name in node_exp_map:
|
||||
stmt.expr = node_exp_map[stmt.expr.name].value
|
||||
return '\n'.join([stmt.serialize() for stmt in self.stmts if not self.auto_gen_node(stmt)]) if self.stmts else ""
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
manager = PassManager(self)
|
||||
new_blocks = manager.renew()
|
||||
always_blocks = manager.gen_all_always_block()
|
||||
return '\n'.join([stmt.verilog_serialize() for stmt in self.stmts])
|
||||
|
||||
return '\n'.join([stmt.verilog_serialize() for stmt in new_blocks.stmts]) + f'\n{always_blocks}' if self.stmts else ""
|
||||
@dataclass(frozen=True)
|
||||
class AlwaysBlock(Statement):
|
||||
stmts: List[Statement]
|
||||
clk: Expression = None
|
||||
|
||||
def serialize(self) -> str:
|
||||
pass
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
cat_table: List[str] = []
|
||||
for stmt in self.stmts:
|
||||
cat_table.append(stmt.verilog_serialize())
|
||||
|
||||
if self.clk is None:
|
||||
declares = '\n' + "\n".join(cat_table)
|
||||
return deleblankline(f"always @(posedge clock) begin\n{indent(declares)}"+ "\nend")
|
||||
else:
|
||||
declares = '\n' + "\n".join(cat_table)
|
||||
return deleblankline(f"always @(posedge {self.clk.verilog_serialize()}) begin\n{indent(declares)}" + "\nend")
|
||||
|
||||
# pass will change the "blocking" feature of Connect Stmt
|
||||
@dataclass(frozen=False)
|
||||
class Connect(Statement):
|
||||
loc: Expression
|
||||
|
@ -735,17 +650,20 @@ class Connect(Statement):
|
|||
info: Info = NoInfo()
|
||||
blocking: bool = True
|
||||
bidirection: bool = False
|
||||
mem: Dict = field(default_factory=dict)
|
||||
|
||||
def serialize(self) -> str:
|
||||
if not self.bidirection:
|
||||
return f'{self.info.serialize()}\n{self.loc.serialize()} <= {self.expr.serialize()}'
|
||||
return f'{self.info.serialize()}{self.loc.serialize()} <= {self.expr.serialize()}'
|
||||
else:
|
||||
return f'{self.info.serialize()}\n{self.loc.serialize()} <= {self.expr.serialize()}\n' + \
|
||||
f'{self.info.serialize()}\n{self.expr.serialize()} <= {self.loc.serialize()}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
op = "=" if self.blocking else "<="
|
||||
return f'assign {self.loc.verilog_serialize()} {op} {self.expr.verilog_serialize()}{self.info.verilog_serialize()};'
|
||||
if self.blocking is False:
|
||||
return f'{self.loc.verilog_serialize()} {op} {self.expr.verilog_serialize()};\t{self.info.verilog_serialize()}'
|
||||
return f'assign\t{self.loc.verilog_serialize()} {op} {self.expr.verilog_serialize()};\t{self.info.verilog_serialize()}'
|
||||
|
||||
|
||||
# Verification
|
||||
|
@ -802,8 +720,10 @@ class DefModule(FirrtlNode, ABC):
|
|||
return f'{typ} {self.name} :{self.info.serialize()}{moduledeclares}\n'
|
||||
|
||||
def verilog_serializeHeader(self, typ: str) -> str:
|
||||
moduledeclares = ' ' + indent(''.join([f'{p.verilog_serialize()}' for p in self.ports]))
|
||||
return f'{typ} {self.name}(\n{deleblankline(moduledeclares)[:-1]}\n);\n'
|
||||
port_declares: List[str] = []
|
||||
for p in self.ports:
|
||||
port_declares.append(indent("\n"+ p.verilog_serialize()))
|
||||
return f"{typ} {self.name}(\t{self.info.verilog_serialize()}{''.join(port_declares)}\n);\n"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -834,7 +754,7 @@ class ExtModule(DefModule):
|
|||
return f'{self.serializeHeader("extmodule")}{s}'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
return f'{self.verilog_serializeHeader("module")}endmodule\n'
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -844,9 +764,85 @@ class Circuit(FirrtlNode):
|
|||
info: Info = NoInfo()
|
||||
|
||||
def serialize(self) -> str:
|
||||
CheckCombLoop()
|
||||
ms = '\n'.join([indent(f'\n{m.serialize()}') for m in self.modules])
|
||||
return f'circuit {self.main} :{self.info.serialize()}{ms}\n'
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
self.requires()
|
||||
ms = ''.join([f'{m.verilog_serialize()}\n' for m in self.modules])
|
||||
return ms
|
||||
|
||||
def requires(self):
|
||||
CheckCombLoop()
|
||||
|
||||
class CheckCombLoop:
|
||||
connect_graph = DAG()
|
||||
reg_map = {}
|
||||
|
||||
@staticmethod
|
||||
def run(stmt: Statement):
|
||||
def check_comb_loop_e(u: Expression, v: Expression):
|
||||
if isinstance(u, (Reference, SubField, SubIndex, SubAccess)):
|
||||
ux, vx = u.serialize(), v.serialize()
|
||||
if ux in CheckCombLoop.reg_map or vx in CheckCombLoop.reg_map:
|
||||
return
|
||||
try:
|
||||
CheckCombLoop.connect_graph.add_node_if_not_exists(vx)
|
||||
CheckCombLoop.connect_graph.add_node_if_not_exists(ux)
|
||||
CheckCombLoop.connect_graph.add_edge(ux, vx)
|
||||
except TransformException as e:
|
||||
raise e
|
||||
elif isinstance(u, Mux):
|
||||
check_comb_loop_e(u.tval, v)
|
||||
check_comb_loop_e(u.fval, v)
|
||||
elif isinstance(u, ValidIf):
|
||||
check_comb_loop_e(u.value, v)
|
||||
elif isinstance(u, DoPrim):
|
||||
for arg in u.args:
|
||||
check_comb_loop_e(arg, v)
|
||||
else:
|
||||
...
|
||||
|
||||
def check_comb_loop_s(s: Statement):
|
||||
if isinstance(s, Connect):
|
||||
if isinstance(s.loc, (Reference, SubField, SubIndex, SubAccess)):
|
||||
check_comb_loop_e(s.expr, s.loc)
|
||||
elif isinstance(s, DefNode):
|
||||
check_comb_loop_e(s.value, Reference(s.name, s.value.typ))
|
||||
elif isinstance(s, Conditionally):
|
||||
check_comb_loop_s(s.conseq)
|
||||
check_comb_loop_s(s.alt)
|
||||
elif isinstance(s, EmptyStmt):
|
||||
...
|
||||
elif isinstance(s, Block):
|
||||
for stmt in s.stmts:
|
||||
check_comb_loop_s(stmt)
|
||||
|
||||
def get_reg_map(s: Statement):
|
||||
if isinstance(s, Block):
|
||||
for sx in s.stmts:
|
||||
if isinstance(sx, DefRegister):
|
||||
CheckCombLoop.reg_map[sx.name] = sx
|
||||
elif isinstance(sx, DefMemPort):
|
||||
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_data"] = DefWire(f"{sx.mem.name}_{sx.name}_data",
|
||||
UIntType(IntWidth(sx.mem.typ.size)))
|
||||
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_addr"] = DefWire(f"{sx.mem.name}_{sx.name}_addr",
|
||||
UIntType(IntWidth(get_binary_width(sx.mem.typ.size))))
|
||||
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_en"] = DefWire(f"{sx.mem.name}_{sx.name}_en",
|
||||
UIntType(IntWidth(1)))
|
||||
if sx.rw is False:
|
||||
CheckCombLoop.reg_map[f"{sx.mem.name}_{sx.name}_mask"] = DefWire(f"{sx.mem.name}_{sx.name}_mask",
|
||||
UIntType(IntWidth(1)))
|
||||
else:
|
||||
...
|
||||
elif isinstance(s, EmptyStmt):
|
||||
...
|
||||
elif isinstance(s, Conditionally):
|
||||
get_reg_map(s.conseq)
|
||||
get_reg_map(s.alt)
|
||||
|
||||
get_reg_map(stmt)
|
||||
check_comb_loop_s(stmt)
|
||||
|
||||
return stmt
|
||||
|
|
|
@ -220,7 +220,7 @@ class Not(PrimOp):
|
|||
return 'not'
|
||||
|
||||
def verilog_op(self):
|
||||
return " ~ "
|
||||
return "!"
|
||||
|
||||
|
||||
# Bitwise And
|
||||
|
@ -280,6 +280,9 @@ class Cat(PrimOp):
|
|||
def __repr__(self):
|
||||
return 'cat'
|
||||
|
||||
def verilog_op(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
# Bit Extraction
|
||||
@dataclass(frozen=True, init=False)
|
||||
|
@ -287,6 +290,9 @@ class Bits(PrimOp):
|
|||
def __repr__(self):
|
||||
return 'bits'
|
||||
|
||||
def verilog_op(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
# Head
|
||||
@dataclass(frozen=True, init=False)
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
def indent(string: str) -> str:
|
||||
return string.replace('\n', '\n ')
|
||||
import math
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import copy, deepcopy
|
||||
|
||||
def indent(string: str, space: int = 1) -> str:
|
||||
return string.replace('\n', '\n' + '\t' * space)
|
||||
|
||||
|
||||
def backspace(string: str) -> str:
|
||||
|
@ -34,3 +38,130 @@ def auto_connect(ma, mb):
|
|||
io_left <<= io_right
|
||||
else:
|
||||
io_right <<= io_left
|
||||
|
||||
def get_binary_width(target):
|
||||
width = 1
|
||||
while target / 2 >= 1:
|
||||
width += 1
|
||||
target = math.floor(target / 2)
|
||||
return width
|
||||
|
||||
class TransformException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
class DAG:
|
||||
""" Directed acyclic graph implementation."""
|
||||
|
||||
def __init__(self):
|
||||
self.graph = OrderedDict()
|
||||
|
||||
def add_node(self, name: str, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if name in graph:
|
||||
...
|
||||
else:
|
||||
graph[name] = set()
|
||||
|
||||
def add_node_if_not_exists(self, name: str, graph = None):
|
||||
try:
|
||||
self.add_node(name, graph = graph)
|
||||
except TransformException as e:
|
||||
raise e
|
||||
|
||||
def delete_node(self, name: str, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if name not in graph:
|
||||
raise TransformException(f'node {name} is not exists.')
|
||||
graph.pop(name)
|
||||
for _, edges in graph.items():
|
||||
if name in edges:
|
||||
edges.remove(name)
|
||||
|
||||
def delete_node_if_exists(self, name: str, graph = None):
|
||||
try:
|
||||
self.delete_node(name, graph = graph)
|
||||
except TransformException as e:
|
||||
raise e
|
||||
|
||||
def add_edge(self, ind_node, dep_node, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if ind_node not in graph or dep_node not in graph:
|
||||
raise TransformException(f'nodes do not exist in graph.')
|
||||
test_graph = deepcopy(graph)
|
||||
test_graph[ind_node].add(dep_node)
|
||||
is_valid, msg = self.validate(test_graph)
|
||||
if is_valid:
|
||||
graph[ind_node].add(dep_node)
|
||||
else:
|
||||
raise TransformException(f'Loop do exist in graph: {msg}')
|
||||
|
||||
def delete_edge(self, ind_node, dep_node, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if dep_node not in graph.get(ind_node, []):
|
||||
raise TransformException(f'This edge does not exist in graph')
|
||||
graph[ind_node].remove(dep_node)
|
||||
|
||||
def ind_nodes(self, graph = None):
|
||||
if graph == None:
|
||||
graph = self.graph
|
||||
dep_nodes = set(
|
||||
node for deps in graph.values() for node in deps
|
||||
)
|
||||
|
||||
return [node for node in graph.keys() if node not in dep_nodes]
|
||||
|
||||
def topological_sort(self, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
result = []
|
||||
in_degree = defaultdict(lambda: 0)
|
||||
|
||||
for u in graph:
|
||||
for v in graph[u]:
|
||||
in_degree[v] += 1
|
||||
|
||||
ready = [node for node in graph if not in_degree[node]]
|
||||
|
||||
while ready:
|
||||
u = ready.pop()
|
||||
result.append(u)
|
||||
for v in graph[u]:
|
||||
in_degree[v] -= 1
|
||||
if in_degree[v] == 0:
|
||||
ready.append(v)
|
||||
if len(result) == len(graph):
|
||||
return result
|
||||
else:
|
||||
raise TransformException(f'graph is not acyclic.')
|
||||
|
||||
|
||||
def validate(self, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if len(self.ind_nodes(graph)) == 0:
|
||||
return False, 'no independent nodes detected.'
|
||||
try:
|
||||
self.topological_sort(graph)
|
||||
except TransformException:
|
||||
return False, 'graph is not acyclic.'
|
||||
return True, 'valid'
|
||||
|
||||
def visit_graph(self, graph = None):
|
||||
visited = []
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
for v in graph:
|
||||
for u in graph[v]:
|
||||
visited.append(f'{v} -> {u}')
|
||||
return visited
|
||||
|
||||
def size(self):
|
||||
return len(self.graph)
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
from typing import List
|
||||
|
||||
class Pass:
|
||||
...
|
||||
|
||||
#Error handling
|
||||
class PassException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class PassExceptions(Exception):
|
||||
def __init__(self, exceptions: List[PassException]):
|
||||
self.message = '\n'.join([str(exception) for exception in exceptions])
|
||||
|
||||
def __str__(self):
|
||||
return '\n' + self.message
|
||||
|
||||
|
||||
class Error:
|
||||
def __init__(self):
|
||||
self.errors: List[PassException] = []
|
||||
|
||||
def append(self, pe: PassException):
|
||||
self.errors.append(pe)
|
||||
|
||||
def trigger(self):
|
||||
if len(self.errors) == 0:
|
||||
return
|
||||
elif len(self.errors) == 1:
|
||||
raise self.errors.pop()
|
||||
else:
|
||||
self.append(f'{len(self.errors)} errors detected!')
|
||||
raise PassExceptions(self.errors)
|
|
@ -0,0 +1,128 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
from pyhcl.ir.low_ir import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoInferring:
|
||||
max_width: int = 0
|
||||
|
||||
def run(self, c: Circuit):
|
||||
modules: List[Module] = []
|
||||
|
||||
def auto_inferring_t(t: Type) -> Type:
|
||||
if isinstance(t, UIntType):
|
||||
if t.width.width == 0:
|
||||
return UIntType(IntWidth(self.max_width))
|
||||
else:
|
||||
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
|
||||
return t
|
||||
elif isinstance(t, SIntType):
|
||||
if t.width.width == 0:
|
||||
return SIntType(IntWidth(self.max_width))
|
||||
else:
|
||||
self.max_width = self.max_width if self.max_width > t.width.width else t.width.width
|
||||
return t
|
||||
elif isinstance(t, (ClockType, ResetType, AsyncResetType)):
|
||||
return t
|
||||
elif isinstance(t, VectorType):
|
||||
return VectorType(auto_inferring_t(t.typ), t.size)
|
||||
elif isinstance(t, MemoryType):
|
||||
return MemoryType(auto_inferring_t(t.typ), t.size)
|
||||
elif isinstance(t, BundleType):
|
||||
return BundleType([Field(fx.name, fx.flip, auto_inferring_t(fx.typ)) for fx in t.fields])
|
||||
else:
|
||||
return t
|
||||
|
||||
def auto_inferring_e(e: Expression, inferring_map: Dict[str, Type]) -> Expression:
|
||||
if isinstance(e, Mux):
|
||||
return Mux(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.tval, inferring_map),
|
||||
auto_inferring_e(e.fval, inferring_map), auto_inferring_t(e.typ))
|
||||
elif isinstance(e, ValidIf):
|
||||
return ValidIf(auto_inferring_e(e.cond, inferring_map), auto_inferring_e(e.value, inferring_map), auto_inferring_t(e.typ))
|
||||
elif isinstance(e, DoPrim):
|
||||
return DoPrim(e.op, [auto_inferring_e(arg, inferring_map) for arg in e.args], e.consts, auto_inferring_t(e.typ))
|
||||
elif isinstance(e, UIntLiteral):
|
||||
if e.width.width < get_binary_width(e.value):
|
||||
return UIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
|
||||
else:
|
||||
return e
|
||||
elif isinstance(e, SIntLiteral):
|
||||
if e.width.width < get_binary_width(e.value) + 1:
|
||||
return SIntLiteral(e.value, IntWidth(get_binary_width(e.value)))
|
||||
else:
|
||||
return e
|
||||
elif isinstance(e, Reference):
|
||||
typ = inferring_map[e.name] if not isinstance(inferring_map[e.name], UnknownType) else auto_inferring_t(e.typ)
|
||||
return Reference(e.name, typ)
|
||||
elif isinstance(e, SubField):
|
||||
expr = auto_inferring_e(e.expr, inferring_map)
|
||||
typ = e.typ
|
||||
for fx in expr.typ.fields:
|
||||
if fx.name == e.name:
|
||||
typ = fx.typ
|
||||
return SubField(expr, e.name, typ)
|
||||
elif isinstance(e, SubIndex):
|
||||
expr = auto_inferring_e(e.expr, inferring_map)
|
||||
return SubIndex(e.name, expr, e.value, expr.typ.typ)
|
||||
elif isinstance(e, SubAccess):
|
||||
expr = auto_inferring_e(e.expr, inferring_map)
|
||||
index = auto_inferring_e(e.index, inferring_map)
|
||||
return SubAccess(expr, index, expr.typ.typ)
|
||||
else:
|
||||
return e
|
||||
|
||||
def auto_inferring_s(s: Statement, inferring_map: Dict[str, Type]) -> Statement:
|
||||
if isinstance(s, Block):
|
||||
stmts: List[Statement] = []
|
||||
for sx in s.stmts:
|
||||
stmts.append(auto_inferring_s(sx, inferring_map))
|
||||
return Block(stmts)
|
||||
elif isinstance(s, Conditionally):
|
||||
return Conditionally(auto_inferring_e(s.pred, inferring_map), auto_inferring_s(s.conseq, inferring_map), auto_inferring_s(s.alt, inferring_map), s.info)
|
||||
elif isinstance(s, DefRegister):
|
||||
clock = auto_inferring_e(s.clock, inferring_map)
|
||||
reset = auto_inferring_e(s.reset, inferring_map)
|
||||
init = auto_inferring_e(s.init, inferring_map)
|
||||
typ = auto_inferring_t(s.typ)
|
||||
inferring_map[s.name] = typ
|
||||
return DefRegister(s.name, typ, clock, reset, init, s.info)
|
||||
elif isinstance(s, DefWire):
|
||||
inferring_map[s.name] = auto_inferring_t(s.typ)
|
||||
return s
|
||||
elif isinstance(s, DefMemory):
|
||||
inferring_map[s.name] = auto_inferring_t(s.memType)
|
||||
return s
|
||||
elif isinstance(s, DefNode):
|
||||
value = auto_inferring_e(s.value, inferring_map)
|
||||
inferring_map[s.name] = value.typ
|
||||
return DefNode(s.name, value, s.info)
|
||||
elif isinstance(s, DefMemPort):
|
||||
clk = auto_inferring_e(s.clk, inferring_map)
|
||||
index = auto_inferring_e(s.index, inferring_map)
|
||||
inferring_map[s.name] = UnknownType()
|
||||
return DefMemPort(s.name, s.mem, index, clk, s.rw, s.info)
|
||||
elif isinstance(s, DefInstance):
|
||||
inferring_map[s.name] = UnknownType()
|
||||
return s
|
||||
elif isinstance(s, Connect):
|
||||
return Connect(auto_inferring_e(s.loc, inferring_map), auto_inferring_e(s.expr, inferring_map), s.info, s.blocking, s.bidirection, s.mem)
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def auto_inferring_m(m: DefModule, inferring_map: Dict[str, Type]) -> DefModule:
|
||||
if isinstance(m, Module):
|
||||
ports: List[Port] = []
|
||||
for p in m.ports:
|
||||
inferring_map[p.name] = auto_inferring_t(p.typ)
|
||||
ports.append(Port(p.name, p.direction, inferring_map[p.name], p.info))
|
||||
body = auto_inferring_s(m.body, inferring_map)
|
||||
return Module(m.name, ports, body, m.typ, m.info)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
inferring_map: Dict[str, Type] = {}
|
||||
modules.append(auto_inferring_m(m, inferring_map))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,106 @@
|
|||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass, PassException, Error
|
||||
from pyhcl.passes.utils import times_f_f, times_g_flip, to_flow
|
||||
from pyhcl.passes.wir import *
|
||||
|
||||
class WrongFlow(PassException):
|
||||
def __init__(self, info: Info, mname: str, expr: str, wrong: Flow, right: Flow):
|
||||
super().__init__(f'{info}: [module {mname}] Expression {expr} is used as a {wrong} but can only be used as a {right}.')
|
||||
|
||||
class CheckFlow(Pass):
|
||||
def run(self, c: Circuit):
|
||||
errors = Error()
|
||||
|
||||
def get_flow(e: Expression, flows: Dict[str, Flow]) -> Flow:
|
||||
if isinstance(e, Reference):
|
||||
return flows[e.name]
|
||||
elif isinstance(e, SubIndex):
|
||||
return get_flow(e.expr, flows)
|
||||
elif isinstance(e, SubAccess):
|
||||
return get_flow(e.expr, flows)
|
||||
elif isinstance(e, SubField):
|
||||
if isinstance(e.expr.typ, BundleType):
|
||||
for f in e.expr.typ.fields:
|
||||
if f.name == e.name:
|
||||
return times_g_flip(get_flow(e.expr, flows), f.flip)
|
||||
|
||||
return SourceFlow()
|
||||
|
||||
|
||||
def flip_q(t: Type) -> bool:
|
||||
def flip_rec(t: Type, f: Orientation) -> bool:
|
||||
if isinstance(t, BundleType):
|
||||
final = True
|
||||
for field in t.fields:
|
||||
final = flip_rec(field.typ, times_f_f(f, field.flip)) and final
|
||||
return final
|
||||
elif isinstance(t, VectorType):
|
||||
return flip_rec(t.typ, f)
|
||||
else:
|
||||
return isinstance(f, Flip)
|
||||
return flip_rec(t, Default())
|
||||
|
||||
def check_flow(info: Info, mname: str, flows: Dict[str, Flow], desired: Flow, e: Expression):
|
||||
flow = get_flow(e, flows)
|
||||
if isinstance(flow, SourceFlow) and isinstance(desired, SinkFlow):
|
||||
errors.append(WrongFlow(info, mname, e.serialize(), desired, flow))
|
||||
|
||||
def check_flow_e(info: Info, mname: str, flows: Dict[str, Flow], e: Expression):
|
||||
if isinstance(e, Mux):
|
||||
for _, ee in e.__dict__.items():
|
||||
if isinstance(ee, Expression):
|
||||
check_flow(info, mname, flows, SourceFlow(), ee)
|
||||
if isinstance(e, DoPrim):
|
||||
for ee in e.args:
|
||||
if isinstance(ee, Expression):
|
||||
check_flow(info, mname, flows, SourceFlow(), ee)
|
||||
|
||||
for _, ee in e.__dict__.items():
|
||||
if isinstance(ee, Expression):
|
||||
check_flow_e(info, mname, flows, ee)
|
||||
|
||||
def check_flow_s(minfo: Info, mname: str, flows: Dict[str, Flow], s: Statement):
|
||||
info = lambda s: minfo if isinstance(s, NoInfo) else s.info
|
||||
if isinstance(s, DefWire):
|
||||
flows[s.name] = DuplexFlow()
|
||||
elif isinstance(s, DefRegister):
|
||||
flows[s.name] = DuplexFlow()
|
||||
elif isinstance(s, DefMemory):
|
||||
flows[s.name] = SourceFlow()
|
||||
elif isinstance(s, DefInstance):
|
||||
flows[s.name] = SourceFlow()
|
||||
elif isinstance(s, DefNode):
|
||||
check_flow(info, mname, flows, SourceFlow(), s.value)
|
||||
flows[s.name] = SourceFlow()
|
||||
elif isinstance(s, DefMemPort):
|
||||
flows[s.name] = SinkFlow()
|
||||
elif isinstance(s, Connect):
|
||||
check_flow(info, mname, flows, SinkFlow(), s.loc)
|
||||
check_flow(info, mname, flows, SourceFlow(), s.expr)
|
||||
...
|
||||
elif isinstance(s, Conditionally):
|
||||
check_flow(info, mname, flows, SourceFlow(), s.pred)
|
||||
else:
|
||||
...
|
||||
|
||||
for _, ss in s.__dict__.items():
|
||||
if isinstance(ss, Expression):
|
||||
check_flow_e(info, mname, flows, ss)
|
||||
if isinstance(ss, Statement):
|
||||
check_flow_s(minfo, mname, flows, ss)
|
||||
|
||||
for m in c.modules:
|
||||
flows: Dict[str, Flow] = {}
|
||||
if hasattr(m, 'ports') and isinstance(m.ports, list):
|
||||
for p in m.ports:
|
||||
flows[p.name] = to_flow(p.direction)
|
||||
|
||||
if hasattr(m, 'body') and isinstance(m.body, Block):
|
||||
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
|
||||
for stmt in m.body.stmts:
|
||||
check_flow_s(m.info, m.name, flows, stmt)
|
||||
|
||||
errors.trigger()
|
||||
return c
|
|
@ -0,0 +1,391 @@
|
|||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass, PassException, Error
|
||||
from pyhcl.passes.utils import ModuleGraph, to_flow, flow, create_exps, get_info, has_flip
|
||||
from pyhcl.passes.wir import *
|
||||
|
||||
|
||||
# ScopeView
|
||||
class ScopeView:
|
||||
def __init__(self, moduleNS: set, scopes: List[set]):
|
||||
self.moduleNS = moduleNS
|
||||
self.scopes = scopes
|
||||
|
||||
def declare(self, name: str):
|
||||
self.moduleNS.add(name)
|
||||
self.scopes[0].add(name)
|
||||
|
||||
# ensures that the name cannot be used again, but prevent references to this name
|
||||
def add_to_namespace(self, name: str):
|
||||
self.moduleNS.add(name)
|
||||
|
||||
def expand_m_port_visibility(self, port: DefMemPort):
|
||||
mem_in_scopes = False
|
||||
def expand_m_port(scope: set, mp: DefMemPort):
|
||||
if mp.mem.name in scope:
|
||||
scope.add(mp.name)
|
||||
return scope
|
||||
self.scopes = list(map(lambda scope: expand_m_port(scope, port), self.scopes))
|
||||
for sx in self.scopes:
|
||||
if port.mem.name in sx:
|
||||
mem_in_scopes = True
|
||||
if mem_in_scopes is False:
|
||||
self.scopes[0].add(port.name)
|
||||
|
||||
|
||||
def legal_decl(self, name: str) -> bool:
|
||||
return name in self.moduleNS
|
||||
|
||||
def legal_ref(self, name: str) -> bool:
|
||||
for s in self.scopes:
|
||||
if name in s:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_ref(self):
|
||||
for s in self.scopes:
|
||||
print(s)
|
||||
|
||||
def child_scope(self):
|
||||
return ScopeView(self.moduleNS, [set()])
|
||||
|
||||
|
||||
def scope_view():
|
||||
return ScopeView(set(), [set()])
|
||||
|
||||
|
||||
# Custom Exceptions
|
||||
class NotUniqueException(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Reference {name} does not have a unique name.')
|
||||
class InvalidLOCException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Invalid connect to an expression that is not a reference or a WritePort.')
|
||||
|
||||
class NegUIntException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] UIntLiteral cannot be negative.')
|
||||
|
||||
class UndecleardReferenceException(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Reference {name} is not declared.')
|
||||
|
||||
class PoisonWithFlipException(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Poison {name} cannot be a bundle type with flips.')
|
||||
|
||||
class MemWithFlipException(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Memory {name} cannot be a bundle type with flips.')
|
||||
|
||||
class IllegalMemLatencyException(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Memory {name} must have non-negative read latency and positive write latency.')
|
||||
|
||||
class RegWithFlipException(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Register {name} cannot be a bundle type with flips.')
|
||||
|
||||
class InvalidAccessException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Invalid access to non-reference.')
|
||||
|
||||
class ModuleNameNotUniqueException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: Repeat definition of module {mname}')
|
||||
|
||||
class DefnameConflictException(PassException):
|
||||
def __init__(self, info: Info, mname: str, defname: str):
|
||||
super().__init__(f'{info}: defname {defname} of extmodule {mname} conflicts with an existing module.')
|
||||
|
||||
class DefnameDifferentPortsException(PassException):
|
||||
def __init__(self, info: Info, mname: str, defname: str):
|
||||
super().__init__(f'{info}: ports of extmodule {mname} with defname {defname} are different for an extmodule with the same defname.')
|
||||
|
||||
class DefnameDifferentPortsException(PassException):
|
||||
def __init__(self, info: Info, name: str):
|
||||
super().__init__(f'{info}: Module {name} is not defined.')
|
||||
|
||||
class IncorrectNumArgsException(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str, n: int):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} requires {n} expression arguments.')
|
||||
|
||||
class IncorrectNumConstsException(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str, n: int):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} requires {n} integer arguments.')
|
||||
|
||||
class NegWidthException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Width cannot be negative.')
|
||||
|
||||
class NegVecSizeException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Vector type size cannot be negative.')
|
||||
|
||||
class NegMemSizeException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Memory size cannot be negative or zero.')
|
||||
|
||||
class InstanceLoop(PassException):
|
||||
def __init__(self, info: Info, mname: str, loop: str):
|
||||
super().__init__(f'{info}: [module {mname}] Has instance loop {loop}.')
|
||||
|
||||
class NoTopModuleException(PassException):
|
||||
def __init__(self, info: Info, name: str):
|
||||
super().__init__(f'{info}: A single module must be named {name}.')
|
||||
|
||||
class NegArgException(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str, value: int):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} argument {value} < 0.')
|
||||
|
||||
class LsbLargerThanMsbException(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str, lsb: int, msb: int):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} lsb {lsb} > {msb}.')
|
||||
|
||||
class ResetInputException(PassException):
|
||||
def __init__(self, info: Info, mname: str, expr: Expression):
|
||||
super().__init__(f'{info}: [module {mname}] Abstract Reset not allowed as top-level input: {expr.serialize()}')
|
||||
|
||||
class ResetExtModuleOutputException(PassException):
|
||||
def __init__(self, info: Info, mname: str, expr: Expression):
|
||||
super().__init__(f'{info}: [module {mname}] Abstract Reset not allowed as ExtModule output: {expr.serialize()}')
|
||||
|
||||
class ModuleNotDefinedException(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: Module {name} is not defined.')
|
||||
|
||||
class CircuitHasNoModules(PassException):
|
||||
def __init__(self, info: Info, cname: str):
|
||||
super().__init__(f'{info}: Circuit {cname} has no modules.')
|
||||
|
||||
class ModuleHasNoPorts(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: Module {mname} has no ports.')
|
||||
|
||||
class ModuleHasNoBody(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: Module {mname} has no body.')
|
||||
|
||||
class CheckHighForm(Pass):
|
||||
def __init__(self, c: Circuit):
|
||||
self.c: Circuit = c
|
||||
self.ms: List[DefModule] = c.modules
|
||||
self.module_names: List[str] = [_.name for _ in c.modules]
|
||||
self.int_module_name: List[str] = [_.name for _ in c.modules if isinstance(_, Module)]
|
||||
self.errors: Error = Error()
|
||||
|
||||
def check_unique_module_name(self):
|
||||
for idx in range(len(self.module_names)):
|
||||
if self.int_module_name[idx] in self.int_module_name[idx:]:
|
||||
m = self.ms[idx]
|
||||
self.errors.append(ModuleNameNotUniqueException(m.info, m.name))
|
||||
|
||||
def check_extmodule(self):
|
||||
for m in self.ms:
|
||||
if isinstance(m, ExtModule) and m.name in self.int_module_name:
|
||||
self.errors.append(DefnameConflictException(m.info, m.name, m.defname))
|
||||
|
||||
def strip_width(self, typ: Type) -> Type:
|
||||
if isinstance(typ, GroundType):
|
||||
return typ.map_width(UnknownWidth)
|
||||
elif isinstance(typ, AggregateType):
|
||||
return typ.map_type(self.strip_width())
|
||||
|
||||
def check_highForm_primOp(self, info: Info, mname: str, e: DoPrim):
|
||||
def correct_num(ne, nc):
|
||||
if isinstance(ne, int) and len(e.args) != ne:
|
||||
self.errors.append(IncorrectNumArgsException(info, mname, e.op.serialize(), ne))
|
||||
|
||||
if len(e.consts) != nc:
|
||||
self.errors.append(IncorrectNumConstsException(info, mname, e.op.serialize(), nc))
|
||||
|
||||
def non_negative_consts():
|
||||
for _ in [c for c in e.consts if c < 0]:
|
||||
self.errors.append(NegArgException(info, mname, e.op.serialize(), _))
|
||||
|
||||
if isinstance(e.op, (Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Dshl, Dshr, And, Or, Xor, Cat)):
|
||||
correct_num(2, 0)
|
||||
elif isinstance(e.op, (AsUInt, AsSInt, AsClock, Cvt, Neq, Not)):
|
||||
correct_num(1, 0)
|
||||
elif isinstance(e.op, AsFixedPoint):
|
||||
correct_num(1, 1)
|
||||
elif isinstance(e.op, (Shl, Shr, Pad, Head, Tail)):
|
||||
correct_num(1, 1)
|
||||
non_negative_consts()
|
||||
elif isinstance(e.op, Bits):
|
||||
correct_num(1, 2)
|
||||
non_negative_consts()
|
||||
if len(e.consts) == 2:
|
||||
msb, lsb = e.consts[0], e.consts[1]
|
||||
if msb < lsb:
|
||||
self.errors.append(LsbLargerThanMsbException(info, mname, e.op.serialize(), lsb, msb))
|
||||
elif isinstance(e.op, (Andr, Orr, Xorr, Neg)):
|
||||
correct_num(1, 0)
|
||||
|
||||
def check_valid_loc(self, info: Info, mname: str, e: Expression):
|
||||
if isinstance(e, (UIntLiteral, SIntLiteral, DoPrim)):
|
||||
self.errors.append(InvalidLOCException(info, mname))
|
||||
|
||||
def check_instance(self, info: Info, child: str, parent: str):
|
||||
if child not in self.module_names:
|
||||
self.errors.append(ModuleNotDefinedException(info, parent, child))
|
||||
childToParent = ModuleGraph().add(parent, child)
|
||||
if childToParent is not None and len(childToParent) > 0:
|
||||
self.errors.append(InstanceLoop(info, parent, "->".join(childToParent)))
|
||||
|
||||
def check_high_form_w(self, info: Info, mname: str, w: Width):
|
||||
if isinstance(w, IntWidth) and w.width < 0:
|
||||
self.errors.append(NegWidthException(info, mname))
|
||||
|
||||
def check_high_form_t(self, info: Info, mname: str, typ: Type):
|
||||
t_attr = typ.__dict__.items()
|
||||
for _, ta in t_attr:
|
||||
if isinstance(ta, Type):
|
||||
self.check_high_form_t(info, mname, ta)
|
||||
if isinstance(ta, Width):
|
||||
self.check_high_form_w(info, mname, ta)
|
||||
|
||||
if isinstance(typ, VectorType) and typ.size < 0:
|
||||
self.errors.append(NegVecSizeException(info, mname))
|
||||
|
||||
def valid_sub_exp(self, info: Info, mname: str, e: Expression):
|
||||
if isinstance(e, (Reference, SubField, SubIndex, SubAccess)):
|
||||
...
|
||||
elif isinstance(e, (Mux, ValidIf)):
|
||||
...
|
||||
else:
|
||||
self.errors.append(InvalidAccessException(info, mname))
|
||||
|
||||
def check_high_form_e(self, info: Info, mname: str, names: ScopeView, e: Expression):
|
||||
e_attr = e.__dict__.items()
|
||||
if isinstance(e, Reference) and names.legal_ref(e.name) is False:
|
||||
self.errors.append(UndecleardReferenceException(info, mname, e.name))
|
||||
...
|
||||
elif isinstance(e, UIntLiteral) and e.value < 0:
|
||||
self.errors.append(NegUIntException(info, mname, e.name))
|
||||
elif isinstance(e, DoPrim):
|
||||
self.check_highForm_primOp(info, mname, e)
|
||||
elif isinstance(e, (Reference, UIntLiteral, Mux, ValidIf)):
|
||||
...
|
||||
elif isinstance(e, SubAccess):
|
||||
self.valid_sub_exp(info, mname, e.expr)
|
||||
else:
|
||||
for _, ea in e_attr:
|
||||
if isinstance(ea, Expression):
|
||||
self.valid_sub_exp(info, mname, ea)
|
||||
|
||||
for _, ea in e_attr:
|
||||
if isinstance(ea, Width):
|
||||
self.check_high_form_w(info, mname + '/' + e.serialize(), ea)
|
||||
if isinstance(ea, Expression):
|
||||
self.check_high_form_e(info, mname, names, ea)
|
||||
|
||||
def check_name(self, info: Info, mname: str, names: ScopeView, referenced: bool, s: Statement):
|
||||
if referenced is False:
|
||||
return
|
||||
if len(s.name) == 0:
|
||||
assert referenced is False, 'A statement with an empty name cannot be used as a reference!'
|
||||
else:
|
||||
if names.legal_decl(s.name) is True:
|
||||
self.errors.append(NotUniqueException(info, mname, s.name))
|
||||
if referenced:
|
||||
names.declare(s.name)
|
||||
else:
|
||||
names.add_to_namespace(s.name)
|
||||
|
||||
def check_high_form_s(self, minfo: Info, mname: str, names: ScopeView, s: Statement):
|
||||
s_attr = s.__dict__.items()
|
||||
t_info = get_info(s)
|
||||
info = t_info if isinstance(t_info, NoInfo) is False else minfo
|
||||
referenced = True if isinstance(s, (DefWire, DefRegister, DefInstance, DefMemory, DefNode, Port)) else False
|
||||
self.check_name(info, mname, names, referenced, s)
|
||||
if isinstance(s, DefRegister):
|
||||
if has_flip(s.typ):
|
||||
self.errors.append(RegWithFlipException(info, mname, s.name))
|
||||
elif isinstance(s, DefMemory):
|
||||
if has_flip(s.memType.typ):
|
||||
self.errors.append(MemWithFlipException(info, mname, s.name))
|
||||
if s.memType.size < 0:
|
||||
self.errors.append(NegMemSizeException(info, mname))
|
||||
elif isinstance(s, DefInstance):
|
||||
self.check_instance(info, mname, s.module)
|
||||
elif isinstance(s, Connect):
|
||||
self.check_valid_loc(info, mname, s.loc)
|
||||
elif isinstance(s, DefMemPort):
|
||||
names.expand_m_port_visibility(s)
|
||||
elif isinstance(s, Block):
|
||||
for stmt in s.stmts:
|
||||
self.check_high_form_s(info, mname, names, stmt)
|
||||
else:
|
||||
...
|
||||
|
||||
for _, sa in s_attr:
|
||||
if isinstance(sa, Type):
|
||||
self.check_high_form_t(info, mname, sa)
|
||||
elif isinstance(sa, Expression):
|
||||
self.check_high_form_e(info, mname, names, sa)
|
||||
|
||||
if isinstance(s, Conditionally):
|
||||
self.check_high_form_s(minfo, mname, names, s.conseq)
|
||||
self.check_high_form_s(minfo, mname, names, s.alt)
|
||||
else:
|
||||
for _, sa in s_attr:
|
||||
if isinstance(sa, Statement):
|
||||
self.check_high_form_s(minfo, mname, names, sa)
|
||||
|
||||
|
||||
|
||||
def check_high_form_p(self, mname: str, names: ScopeView, p: Port):
|
||||
if names.legal_decl(p.name) is True:
|
||||
self.errors.append(NotUniqueException(NoInfo, mname, p.name))
|
||||
names.declare(p.name)
|
||||
self.check_high_form_t(p.info, mname, p.typ)
|
||||
|
||||
def find_bad_reset_type_ports(self, m: DefModule, dir: Direction):
|
||||
bad_reset_type_ports = []
|
||||
bad = to_flow(dir)
|
||||
gen = ((create_exps(ref), p1) for (ref, p1) in [(Reference(p.name, p.typ), p) for p in m.ports])
|
||||
for expr, port in gen:
|
||||
if isinstance(expr, list):
|
||||
for exx in expr:
|
||||
if exx is not None and exx.typ == ResetType and flow(exx) == bad:
|
||||
bad_reset_type_ports.append((port, exx))
|
||||
else:
|
||||
if expr is not None and expr.typ == ResetType and flow(expr) == bad:
|
||||
bad_reset_type_ports.append((port, expr))
|
||||
|
||||
|
||||
return bad_reset_type_ports
|
||||
|
||||
def check_high_form_m(self, m: DefModule):
|
||||
names = scope_view()
|
||||
if hasattr(m, 'ports') and isinstance(m.ports, list):
|
||||
for p in m.ports:
|
||||
self.check_high_form_p(m.name, names, p)
|
||||
else:
|
||||
self.errors.append(ModuleHasNoPorts(m.info, m.name))
|
||||
|
||||
if hasattr(m, 'body') and isinstance(m.body, Block):
|
||||
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
|
||||
for s in m.body.stmts:
|
||||
self.check_high_form_s(m.info, m.name, names, s)
|
||||
else:
|
||||
self.errors.append(ModuleHasNoBody(m.info, m.name))
|
||||
|
||||
if isinstance(m, ExtModule):
|
||||
for port, expr in self.find_bad_reset_type_ports(m, Output):
|
||||
self.errors.append(ResetExtModuleOutputException(port.info, m.name, expr))
|
||||
|
||||
def run(self):
|
||||
if hasattr(self.c, 'modules') and isinstance(self.c.modules, list):
|
||||
for m in self.c.modules:
|
||||
self.check_high_form_m(m)
|
||||
else:
|
||||
self.errors.append(CircuitHasNoModules(self.c.info, self.c.main))
|
||||
|
||||
if self.c.main not in self.int_module_name:
|
||||
self.errors.append(NoTopModuleException(self.c.info, self.c.main))
|
||||
|
||||
self.errors.trigger()
|
||||
return self.c
|
|
@ -0,0 +1,288 @@
|
|||
from typing import List
|
||||
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass, PassException, Error
|
||||
from pyhcl.passes.utils import times_f_f
|
||||
from pyhcl.passes.wir import WrappedType
|
||||
|
||||
# Custom Exceptions
|
||||
class SubfieldNotInBundle(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Subfield {name} is not in bundle.')
|
||||
|
||||
class SubfieldOnNonBundle(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Subfield {name} is accessed on non-bundle.')
|
||||
|
||||
class IndexTooLarge(PassException):
|
||||
def __init__(self, info: Info, mname: str, value: int):
|
||||
super().__init__(f'{info}: [module {mname}] Index with value {value} is too large.')
|
||||
|
||||
class IndexOnNonVector(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Index illegal on non-vector type.')
|
||||
|
||||
class AccessIndexNotUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Access index must be a UInt type')
|
||||
|
||||
class IndexNotUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Index is not of UIntType.')
|
||||
|
||||
class EnableNotUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Enable is not of UIntType.')
|
||||
|
||||
class InvalidConnect(PassException):
|
||||
def __init__(self, info: Info, mname: str, cond: str, lhs: Expression, rhs: Expression):
|
||||
ltyp = f'\t{lhs.serialize()}: {lhs.typ.serialize()}'
|
||||
rtyp = f'\t{rhs.serialize()}: {rhs.typ.serialize()}'
|
||||
super().__init__(f'{info}: [module {mname}] Type mismatch in \'{cond}\'.\n{ltyp}\n{rtyp}')
|
||||
|
||||
class ReqClk(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Requires a clock typed signal.')
|
||||
|
||||
class RegReqClk(PassException):
|
||||
def __init__(self, info: Info, mname: str, name: str):
|
||||
super().__init__(f'{info}: [module {mname}] Register {name} requires a clock typed signal.')
|
||||
|
||||
class EnNotUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Enable must be a 1-bit UIntType typed signal.')
|
||||
|
||||
class PredNotUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Predicate not a 1-bit UIntType.')
|
||||
|
||||
class OpNotGround(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} cannot operate on non-ground types.')
|
||||
|
||||
class OpNotUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str, e: str):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} requires argument {e} to be a UInt type.')
|
||||
|
||||
class InvalidRegInit(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Type of init must match type of DefRegister.')
|
||||
|
||||
class OpNotAllUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} requires all arguments to be UInt type.')
|
||||
|
||||
class OpNotAllSameType(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} requires all operands to have the same type.')
|
||||
|
||||
class OpNotCorrectType(PassException):
|
||||
def __init__(self, info: Info, mname: str, op: str, typs: List[str]):
|
||||
super().__init__(f'{info}: [module {mname}] Primop {op} does not have correct arg types: {typs}.')
|
||||
|
||||
class NodePassiveType(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Node must be a passive type.')
|
||||
|
||||
class MuxSameType(PassException):
|
||||
def __init__(self, info: Info, mname: str, t1: str, t2: str):
|
||||
super().__init__(f'{info}: [module {mname}] Must mux between equivalent types: {t1} != {t2}.')
|
||||
|
||||
class MuxPassiveTypes(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Must mux between passive types.')
|
||||
|
||||
class MuxCondUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] A mux condition must be of type 1-bit UInt.')
|
||||
|
||||
class MuxClock(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Firrtl does not support muxing clocks.')
|
||||
|
||||
class ValidIfPassiveTypes(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] Must validif a passive type.')
|
||||
|
||||
class ValidIfCondUInt(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [module {mname}] A validif condition must be of type UInt.')
|
||||
|
||||
class IllegalResetType(PassException):
|
||||
def __init__(self, info: Info, mname: str, exp: str):
|
||||
super().__init__(f'{info}: [module {mname}] Register resets must have type Reset, AsyncReset, or UInt<1>: {exp}.')
|
||||
|
||||
class IllegalUnknownType(PassException):
|
||||
def __init__(self, info: Info, mname: str, exp: str):
|
||||
super().__init__(f'{info}: [module {mname}] Uninferred type: {exp}.')
|
||||
|
||||
# TODO PrintfArgNotGround | OpNoMixFix | OpNotAnalog | IllegalAnalogDeclaration | IllegalAttachExp
|
||||
|
||||
class CheckTypes(Pass):
|
||||
def legal_reset_type(self, typ: Type) -> bool:
|
||||
if isinstance(typ, UIntType) and isinstance(typ.width, IntWidth):
|
||||
return typ.width.width == 1
|
||||
elif isinstance(typ, AsyncResetType):
|
||||
return True
|
||||
elif isinstance(typ, ResetType):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def legal_cond_type(self, typ: Type) -> bool:
|
||||
if isinstance(typ, UIntType) and isinstance(typ.width, IntWidth):
|
||||
return typ.width.width == 1
|
||||
elif isinstance(typ, UIntType):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def bulk_equals(self, t1: Type, t2: Type, flip1: Orientation, flip2: Orientation) -> bool:
|
||||
if isinstance(t1, ClockType) and isinstance(t2, ClockType):
|
||||
return flip1 == flip2
|
||||
elif isinstance(t1, UIntType) and isinstance(t2, UIntType):
|
||||
return flip1 == flip2
|
||||
elif isinstance(t1, SIntType) and isinstance(t2, SIntType):
|
||||
return flip1 == flip2
|
||||
elif isinstance(t1, AsyncResetType) and isinstance(t2, AsyncResetType):
|
||||
return flip1 == flip2
|
||||
elif isinstance(t1, ResetType):
|
||||
return self.legal_reset_type(t2) and flip1 == flip2
|
||||
elif isinstance(t2, ResetType):
|
||||
return self.legal_reset_type(t1) and flip1 == flip2
|
||||
elif isinstance(t1, BundleType) and isinstance(t2, BundleType):
|
||||
t1_fields = {}
|
||||
for f1 in t1.fields:
|
||||
t1_fields[f1.name] = (f1.typ, f1.flip)
|
||||
for f2 in f2.fields:
|
||||
if f2.name in t1_fields.keys():
|
||||
t1_flip = t1_fields[f2.name].typ
|
||||
return self.bulk_equals(t1_flip, f2.typ, times_f_f(flip1, t1_flip) ,times_f_f(flip2, f2.flip))
|
||||
else:
|
||||
return True
|
||||
elif isinstance(t1, VectorType) and isinstance(t2, VectorType):
|
||||
return self.bulk_equals(t1.typ, t2.typ, flip1, flip2)
|
||||
else:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def valid_connect(locTyp: Type, exprTyp: Type) -> bool:
|
||||
if isinstance(locTyp, (ClockType, UIntType)) and isinstance(exprTyp, (ClockType, UIntType)):
|
||||
return True
|
||||
return type(locTyp) == type(exprTyp)
|
||||
|
||||
def valid_connects(self, c: Connect) -> bool:
|
||||
return CheckTypes.valid_connect(c.loc.typ, c.expr.typ)
|
||||
|
||||
def run(self, c: Circuit):
|
||||
errors = Error()
|
||||
|
||||
def passive(t: Type) -> bool:
|
||||
if isinstance(t, (UIntType, SIntType)):
|
||||
return True
|
||||
elif isinstance(t, VectorType):
|
||||
return passive(t.typ)
|
||||
elif isinstance(t, BundleType):
|
||||
final = True
|
||||
for f in t.fields:
|
||||
final = f.flip == Default and passive(f.typ) and final
|
||||
return final
|
||||
else:
|
||||
return True
|
||||
|
||||
def check_typs_primop(info: Info, mname: str, e: DoPrim):
|
||||
def check_all_typs(exprs: List[Expression], okUInt: bool, okSInt: bool, okClock: bool, okAsync: bool):
|
||||
for expr in exprs:
|
||||
if isinstance(expr.typ, UIntType) and okUInt is False:
|
||||
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
|
||||
elif isinstance(expr.typ, UIntType) and okSInt is False:
|
||||
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
|
||||
elif isinstance(expr.typ, ClockType) and okClock is False:
|
||||
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
|
||||
elif isinstance(expr.typ, AsyncResetType) and okAsync is False:
|
||||
errors.append(OpNotCorrectType(info, mname, e.op.serialize(), [expr.typ.serialize() for expr in exprs]))
|
||||
|
||||
if isinstance(e.op, (AsUInt, AsSInt, AsClock, AsyncResetType)):
|
||||
...
|
||||
elif isinstance(e.op, (Dshl, Dshr)):
|
||||
check_all_typs(list(e.args[0]), True, True, False, False)
|
||||
check_all_typs(e.args[1:], True, False, False, False)
|
||||
elif isinstance(e.op, (Add, Sub, Mul, Lt, Leq, Gt, Geq, Eq, Neq)):
|
||||
check_all_typs(e.args, True, True, False, False)
|
||||
elif isinstance(e.op, (Pad, Bits, Head, Tail)):
|
||||
check_all_typs(e.args, True, True, False, False)
|
||||
elif isinstance(e.op, (Shr, Shl, Cat)):
|
||||
check_all_typs(e.args, True, True, False, False)
|
||||
else:
|
||||
check_all_typs(e.args, True, True, False, False)
|
||||
|
||||
def check_types_e(info: Info, mname: str, e: Expression):
|
||||
if isinstance(e, DoPrim):
|
||||
check_typs_primop(info, mname, e)
|
||||
elif isinstance(e, Mux):
|
||||
if WrappedType(e.tval.typ) != WrappedType(e.fval.typ):
|
||||
errors.append(MuxSameType(info, mname, e.tval.typ.serialize(), e.fval.typ.serialize()))
|
||||
if passive(e.typ) is False:
|
||||
errors.append(MuxPassiveTypes(info, mname))
|
||||
if self.legal_cond_type(e.cond.typ) is False:
|
||||
errors.append(MuxCondUInt(info, mname))
|
||||
elif isinstance(e, ValidIf):
|
||||
if passive(e.typ) is False:
|
||||
errors.append(ValidIfPassiveTypes(info, mname))
|
||||
if isinstance(e.cond.typ, UIntType):
|
||||
...
|
||||
else:
|
||||
errors.append(ValidIfCondUInt(info, mname))
|
||||
else:
|
||||
...
|
||||
|
||||
for _, ee in e.__dict__.items():
|
||||
if isinstance(ee, Expression):
|
||||
check_types_e(info, mname, ee)
|
||||
|
||||
def check_types_s(minfo: Info, mname: str, s: Statement):
|
||||
def get_info(s):
|
||||
if isinstance(s, NoInfo):
|
||||
return minfo
|
||||
else:
|
||||
return s
|
||||
|
||||
if isinstance(s, Connect) and self.valid_connects(s) is False:
|
||||
con_msg = Connect(s.loc, s.expr, NoInfo()).serialize()
|
||||
errors.append(InvalidConnect(get_info(s), mname, con_msg, s.loc, s.expr))
|
||||
elif isinstance(s, DefRegister):
|
||||
if isinstance(s.init, Expression) and WrappedType(s.typ) != WrappedType(s.init.typ):
|
||||
errors.append(InvalidRegInit(get_info(s), mname))
|
||||
if isinstance(s.init, Expression) and CheckTypes.valid_connect(s.typ, s.init.typ) is False:
|
||||
con_msg = Connect(s.loc, s.expr, NoInfo()).serialize()
|
||||
errors.append(InvalidConnect(get_info(s), mname, con_msg, Reference(s.name, s.typ), s.init))
|
||||
|
||||
if isinstance(s.init, Expression) and self.legal_reset_type(s.reset.typ) is False:
|
||||
errors.append(IllegalResetType(get_info(s), mname, s.name))
|
||||
if not isinstance(s.clock.typ, ClockType) or not isinstance(s.clock.typ.width, IntWidth) or s.clock.typ.width.width != 1:
|
||||
errors.append(RegReqClk(get_info(s), mname, s.name))
|
||||
elif isinstance(s, Conditionally) and self.legal_cond_type(s.pred.typ) is False:
|
||||
errors.append(PredNotUInt(get_info(s), mname))
|
||||
elif isinstance(s, DefNode):
|
||||
if passive(s.value.typ) is False:
|
||||
errors.append(NodePassiveType(get_info(s), mname))
|
||||
elif isinstance(s, DefMemory):
|
||||
...
|
||||
else:
|
||||
...
|
||||
|
||||
for _, ss in s.__dict__.items():
|
||||
if isinstance(ss, Statement):
|
||||
check_types_s(get_info(s), mname, ss)
|
||||
if isinstance(ss, Expression):
|
||||
check_types_e(get_info(s), mname, ss)
|
||||
|
||||
for m in c.modules:
|
||||
if hasattr(m, 'body') and isinstance(m.body, Block):
|
||||
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
|
||||
for s in m.body.stmts:
|
||||
check_types_s(m.info, m.name, s)
|
||||
|
||||
errors.trigger()
|
||||
return c
|
|
@ -0,0 +1,181 @@
|
|||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass, PassException, Error
|
||||
from pyhcl.passes.utils import get_width, get_info
|
||||
from pyhcl.passes.check_types import IllegalResetType, CheckTypes, InvalidConnect
|
||||
|
||||
# MaxWidth
|
||||
maxWidth = 1000000
|
||||
|
||||
class UninferredWidth(PassException):
|
||||
def __init__(self, info: Info, target: str):
|
||||
super().__init__(f'{info}: Uninferred width for target below. (Did you forget to assign to it?) \n{target}')
|
||||
|
||||
class InvalidRange(PassException):
|
||||
def __init__(self, info: Info, target: str, i: Type):
|
||||
super().__init__(f'{info}: Invalid range {i.serialize()} for target below. (Are the bounds valid?) \n{target}')
|
||||
|
||||
class WidthTooSmall(PassException):
|
||||
def __init__(self, info: Info, mname: str, b: int):
|
||||
super().__init__(f'{info} : [target {mname}] Width too small for constant {b}.')
|
||||
|
||||
class WidthTooBig(PassException):
|
||||
def __init__(self, info: Info, mname: str, b: int):
|
||||
super().__init__(f'{info} : [target ${mname}] Width {b} greater than max allowed width of {maxWidth} bits')
|
||||
|
||||
class DshlTooBig(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info} : [target {mname}] Width of dshl shift amount must be less than {maxWidth} bits.')
|
||||
|
||||
class MultiBitAsClock(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info} : [target {mname}] Cannot cast a multi-bit signal to a Clock.')
|
||||
|
||||
class MultiBitAsAsyncReset(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info} : [target {mname}] Cannot cast a multi-bit signal to an AsyncReset.')
|
||||
|
||||
class NegWidthException(PassException):
|
||||
def __init__(self, info: Info, mname: str):
|
||||
super().__init__(f'{info}: [target {mname}] Width cannot be negative or zero.')
|
||||
|
||||
class BitsWidthException(PassException):
|
||||
def __init__(self, info: Info, mname: str, hi: int, width: int, exp: str):
|
||||
super().__init__(f'{info}: [target {mname}] High bit {hi} in bits operator is larger than input width {width} in {exp}.')
|
||||
|
||||
class HeadWidthException(PassException):
|
||||
def __init__(self, info: Info, mname: str, n: int, width: int):
|
||||
super().__init__(f'{info}: [target {mname}] Parameter {n} in head operator is larger than input width {width}.')
|
||||
|
||||
class TailWidthException(PassException):
|
||||
def __init__(self, info: Info, mname: str, n: int, width: int):
|
||||
super().__init__(f'{info}: [target {mname}] Parameter {n} in tail operator is larger than input width {width}.')
|
||||
|
||||
class CheckWidths(Pass):
|
||||
|
||||
def run(self, c: Circuit):
|
||||
errors = Error()
|
||||
|
||||
def gen_target(name: str, subname: str) -> str:
|
||||
return f'{name}-{subname}'
|
||||
|
||||
def check_width_w(info: Info, target: str, t: Type, w: Width):
|
||||
if isinstance(w, IntWidth) and w.width >= maxWidth:
|
||||
errors.append(WidthTooBig(info, target, w.width))
|
||||
elif isinstance(w, IntWidth):
|
||||
...
|
||||
else:
|
||||
errors.append(UninferredWidth(info, target))
|
||||
|
||||
def has_width(typ: Type) -> bool:
|
||||
if isinstance(typ, GroundType) and hasattr(typ, 'width') and isinstance(typ.width, IntWidth):
|
||||
return True
|
||||
elif isinstance(typ, GroundType):
|
||||
return False
|
||||
else:
|
||||
raise PassException(f'hasWidth - {typ}')
|
||||
|
||||
def check_width_t(info: Info, target: str, t: Type):
|
||||
if isinstance(t, BundleType):
|
||||
for f in t.fields:
|
||||
check_width_f(info, target, f)
|
||||
else:
|
||||
for _, tt in t.__dict__.items():
|
||||
if isinstance(tt, Type):
|
||||
check_width_t(info, target, tt)
|
||||
|
||||
for _, tt in t.__dict__.items():
|
||||
if isinstance(tt, Width):
|
||||
check_width_w(info, target, t, tt)
|
||||
|
||||
def check_width_f(info: Info, target: str, f: Field):
|
||||
check_width_t(info, target, f.typ)
|
||||
|
||||
def check_width_e_leaf(info: Info, target: str, expr: Expression):
|
||||
if isinstance(expr, UIntLiteral) and get_binary_width(expr.value) > get_width(expr.width):
|
||||
errors.append(WidthTooSmall(info, target, expr.value))
|
||||
elif isinstance(expr, SIntLiteral) and get_binary_width(expr.value) + 1 > get_width(expr.width):
|
||||
errors.append(WidthTooSmall(info, target, expr.value))
|
||||
elif isinstance(expr, DoPrim) and len(expr.args) == 2:
|
||||
if isinstance(expr.op, Dshl) and has_width(expr.args[0].typ) and get_width(expr.args[1].typ.width) > maxWidth:
|
||||
errors.append(DshlTooBig(info, target))
|
||||
elif isinstance(expr, DoPrim) and len(expr.args) == 1:
|
||||
if isinstance(expr.op, Bits) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.consts[0]:
|
||||
errors.append(BitsWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width), expr.serialize()))
|
||||
elif isinstance(expr.op, Head) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.args[0]:
|
||||
errors.append(HeadWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width)))
|
||||
elif isinstance(expr.op, Tail) and has_width(expr.args[0].typ) and get_width(expr.args[0].typ.width) <= expr.args[0]:
|
||||
errors.append(TailWidthException(info, target, expr.consts[0], get_width(expr.args[0].typ.width)))
|
||||
elif isinstance(expr.op, AsClock) and expr.consts[0] != 1:
|
||||
errors.append(MultiBitAsClock(info, target))
|
||||
|
||||
def check_width_e(info: Info, target: str, rec_depth: int, e: Expression):
|
||||
check_width_e_leaf(info, target, e)
|
||||
if isinstance(e, (Mux, ValidIf, DoPrim)):
|
||||
if rec_depth > 0:
|
||||
for _, ee in e.__dict__.items():
|
||||
if isinstance(ee, Expression):
|
||||
check_width_e(info, target, rec_depth - 1, ee)
|
||||
else:
|
||||
check_width_e_dfs(info, target, e)
|
||||
|
||||
def check_width_e_dfs(info: Info, target: str, expr: Expression):
|
||||
stack = expr.__dict__.items()
|
||||
def push(e: Expression):
|
||||
stack.append(e)
|
||||
while len(stack) > 0:
|
||||
current = stack
|
||||
check_width_e_leaf(info, target, current)
|
||||
for _, leaf in current.__dict__.items():
|
||||
if isinstance(leaf, Expression):
|
||||
push(leaf)
|
||||
|
||||
|
||||
def check_width_s(minfo: Info, target: str, s: Statement):
|
||||
info = get_info(s)
|
||||
if isinstance(info, NoInfo):
|
||||
info = minfo
|
||||
for _, ss in s.__dict__.items():
|
||||
if isinstance(ss, Expression):
|
||||
check_width_e(info, target, 4, ss)
|
||||
if isinstance(ss, Statement):
|
||||
check_width_s(info, target, ss)
|
||||
if isinstance(ss, Type):
|
||||
check_width_t(info, target, ss)
|
||||
|
||||
if isinstance(s, DefRegister):
|
||||
sx = s.reset.typ if isinstance(s.reset, Expression) else None
|
||||
if sx is None:
|
||||
...
|
||||
elif isinstance(sx, UIntType) and get_width(sx.width) == 1:
|
||||
...
|
||||
elif isinstance(sx, AsyncResetType):
|
||||
...
|
||||
elif isinstance(sx, ResetType):
|
||||
...
|
||||
else:
|
||||
errors.append(IllegalResetType(info, target, s.name))
|
||||
|
||||
if isinstance(s.init, Expression) and CheckTypes.valid_connect(s.typ, s.init.typ) is False:
|
||||
con_msg = DefRegister(s.name, s.typ, s.clock, s.reset, s.init, NoInfo())
|
||||
errors.append(InvalidConnect(info, target, con_msg, _, s.init))
|
||||
|
||||
def check_width_p(minfo: Info, target: str, p: Port):
|
||||
check_width_t(p.info, target, p.typ)
|
||||
|
||||
def check_width_m(target: str, m: DefModule):
|
||||
if hasattr(m, 'ports') and isinstance(m.ports, list):
|
||||
for p in m.ports:
|
||||
check_width_p(m.info, gen_target(target, m.name), p)
|
||||
|
||||
if hasattr(m, 'body') and isinstance(m.body, Block):
|
||||
if hasattr(m.body, 'stmts') and isinstance(m.body.stmts, list):
|
||||
for s in m.body.stmts:
|
||||
check_width_s(m.info, gen_target(target, m.name), s)
|
||||
|
||||
for m in c.modules:
|
||||
check_width_m(c.main, m)
|
||||
|
||||
errors.trigger()
|
||||
return c
|
|
@ -0,0 +1,150 @@
|
|||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
|
||||
@dataclass
|
||||
class ExpandAggregate(Pass):
|
||||
def run(self, c: Circuit) -> Circuit:
|
||||
modules: List[DefModule] = []
|
||||
|
||||
def flip_direction(d: Direction) -> Direction:
|
||||
if isinstance(d, Output):
|
||||
return Input()
|
||||
else:
|
||||
return Output()
|
||||
|
||||
def flatten_vector(name: str, t: Type) -> list:
|
||||
decs = []
|
||||
if isinstance(t, VectorType):
|
||||
for nx, tx in [(f"{name}_{i}", t.typ) for i in range(t.size)]:
|
||||
if isinstance(tx, VectorType):
|
||||
decs = decs + flatten_vector(nx, tx)
|
||||
elif isinstance(tx, BundleType):
|
||||
decs = decs + [(nxx, txx) for nxx, _, txx in flatten_bundle(nx, tx)]
|
||||
else:
|
||||
decs.append((nx, tx))
|
||||
return decs
|
||||
|
||||
def flatten_bundle(name: str, t: Type) -> list:
|
||||
decs = []
|
||||
if isinstance(t, BundleType):
|
||||
for nx, fx, tx in [(f"{name}_{f.name}", f.flip, f.typ) for f in t.fields]:
|
||||
if isinstance(tx, BundleType):
|
||||
decs = decs + flatten_bundle(nx, tx)
|
||||
elif isinstance(tx, VectorType):
|
||||
decs = decs + [(nxx, fx, txx) for nxx, txx in flatten_vector(nx, tx)]
|
||||
else:
|
||||
decs.append((nx, fx, tx))
|
||||
return decs
|
||||
|
||||
def expand_aggregate_wire(stmt: Statement, stmts: List[Statement]):
|
||||
if isinstance(stmt.typ, VectorType):
|
||||
typs = flatten_vector(stmt.name, stmt.typ)
|
||||
for nx, tx in typs:
|
||||
stmts.append(DefWire(nx, tx, stmt.info))
|
||||
elif isinstance(stmt.typ, BundleType):
|
||||
typs = flatten_bundle(stmt.name, stmt.typ)
|
||||
for nx, _, tx in typs:
|
||||
stmts.append(DefWire(nx, tx, stmt.info))
|
||||
else:
|
||||
stmts.append(stmt)
|
||||
|
||||
def expand_aggregate_node(stmt: Statement, stmts: List[Statement]):
|
||||
value = stmt.value
|
||||
if isinstance(value.typ, VectorType):
|
||||
if isinstance(value, Mux):
|
||||
tval, fval = value.tval, value.fval
|
||||
tval_typs, fval_typs = flatten_vector(tval.name, tval.typ), flatten_vector(fval.name, fval.typ)
|
||||
typs = flatten_vector(stmt.name, value.typ)
|
||||
for i in range(len(typs)):
|
||||
stmts.append(DefNode(typs[i][0], Mux(value.cond,
|
||||
Reference(tval_typs[i][0], tval_typs[i][1]), Reference(fval_typs[i][0], fval_typs[i][1]), typs[i][1])))
|
||||
if isinstance(value, ValidIf):
|
||||
val = value.value
|
||||
val_typs = flatten_vector(val.name, val.value)
|
||||
typs = flatten_vector(stmt.name, value.typ)
|
||||
for i in range(len(typs)):
|
||||
stmts.append(DefNode(typs[i][0], ValidIf(value.cond,
|
||||
Reference(val_typs[i][0], val_typs[i][1]), typs[i][1])))
|
||||
if isinstance(value, DoPrim):
|
||||
...
|
||||
elif isinstance(value.typ, BundleType):
|
||||
if isinstance(value, Mux):
|
||||
tval, fval = value.tval, value.fval
|
||||
tval_typs, fval_typs = flatten_bundle(tval.name, tval.typ), flatten_bundle(fval.name, fval.typ)
|
||||
typs = flatten_bundle(stmt.name, value.typ)
|
||||
for i in range(len(typs)):
|
||||
stmts.append(DefNode(typs[i][0], Mux(value.cond,
|
||||
Reference(tval_typs[i][0], tval_typs[i][2]), Reference(fval_typs[i][0], fval_typs[i][2]), typs[i][2])))
|
||||
if isinstance(value, ValidIf):
|
||||
val = value.value
|
||||
val_typs = flatten_bundle(val.name, val.value)
|
||||
typs = flatten_bundle(stmt.name, value.typ)
|
||||
for i in range(len(typs)):
|
||||
stmts.append(DefNode(typs[i][0], ValidIf(value.cond,
|
||||
Reference(val_typs[i][0], val_typs[i][2]), typs[i][2])))
|
||||
if isinstance(value, DoPrim):
|
||||
...
|
||||
else:
|
||||
stmts.append(stmt)
|
||||
|
||||
def expand_aggregate_reg(stmt: Statement, stmts: List[Statement]):
|
||||
typ = stmt.typ
|
||||
if isinstance(typ, VectorType):
|
||||
typs = flatten_vector(stmt.name, typ)
|
||||
for nx, tx in typs:
|
||||
init = Reference(nx, tx)
|
||||
stmts.append(DefRegister(nx, tx, stmt.clock, stmt.reset, init, stmt.info))
|
||||
elif isinstance(typ, BundleType):
|
||||
typs = flatten_bundle(stmt.name, stmt.typ)
|
||||
for nx, _, tx in typs:
|
||||
stmts.append(DefRegister(nx, tx, stmt.clock, stmt.reset, stmt.init, stmt.info))
|
||||
else:
|
||||
stmts.append(stmt)
|
||||
|
||||
def expand_aggregate_s(stmts: List[Statement]) -> List[Statement]:
|
||||
new_stmts = []
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, DefWire):
|
||||
expand_aggregate_wire(stmt, new_stmts)
|
||||
elif isinstance(stmt, DefNode):
|
||||
expand_aggregate_node(stmt, new_stmts)
|
||||
elif isinstance(stmt, DefRegister):
|
||||
expand_aggregate_reg(stmt, new_stmts)
|
||||
else:
|
||||
new_stmts.append(stmt)
|
||||
return new_stmts
|
||||
|
||||
def expand_aggregate_p(p: Port, ports: List[Port]):
|
||||
if isinstance(p.typ, VectorType):
|
||||
typs = flatten_vector(p.name, p.typ)
|
||||
for nx, tx in typs:
|
||||
ports.append(Port(nx, p.direction, tx, p.info))
|
||||
elif isinstance(p.typ, BundleType):
|
||||
typs = flatten_bundle(p.name, p.typ)
|
||||
for nx, fx, tx in typs:
|
||||
dir = p.direction if isinstance(fx, Default) else flip_direction(p.direction)
|
||||
ports.append(Port(nx, dir, tx, p.info))
|
||||
else:
|
||||
ports.append(p)
|
||||
|
||||
def expand_aggregate_ps(ps: List[Port]) -> List[Port]:
|
||||
new_ports = []
|
||||
for p in ps:
|
||||
expand_aggregate_p(p, new_ports)
|
||||
return new_ports
|
||||
|
||||
def expand_aggregate_m(m: DefModule) -> DefModule:
|
||||
if isinstance(m, ExtModule):
|
||||
return m
|
||||
if not hasattr(m, 'body') or not isinstance(m.body, Block):
|
||||
return m
|
||||
if not hasattr(m.body, 'stmts') or not isinstance(m.body.stmts, list):
|
||||
return m
|
||||
|
||||
return Module(m.name, expand_aggregate_ps(m.ports), Block(expand_aggregate_s(m.body.stmts)), m.typ, m.info)
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(expand_aggregate_m(m))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,97 @@
|
|||
from typing import List, Dict
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
from pyhcl.passes.utils import get_binary_width
|
||||
|
||||
DEFAULT_READ_LATENCY = 0
|
||||
DEFAULT_WRITE_LATENCY = 1
|
||||
|
||||
@dataclass
|
||||
class ExpandMemory(Pass):
|
||||
def run(self, c: Circuit):
|
||||
|
||||
def get_mem_ports(stmts: List[Statement], writes: Dict[str, List[Statement]], reads: Dict[str, List[Statement]]):
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, DefMemPort):
|
||||
if stmt.rw is True:
|
||||
if stmt.mem.name in reads:
|
||||
reads[stmt.mem.name] = reads[stmt.mem.name] + [stmt.name]
|
||||
else:
|
||||
reads[stmt.mem.name] = [stmt.name]
|
||||
else:
|
||||
if stmt.mem.name in writes:
|
||||
writes[stmt.mem.name] = writes[stmt.mem.name] + [stmt.name]
|
||||
else:
|
||||
writes[stmt.mem.name] = [stmt.name]
|
||||
def expand_mem_port(stmts: List[Statement], target: Statement):
|
||||
addr_width = IntWidth(get_binary_width(target.mem.typ.size))
|
||||
# addr
|
||||
stmts.append(Connect(
|
||||
SubField(SubField(Reference(target.mem.name, UIntType(addr_width)),target.name, UIntType(addr_width)), 'addr', UIntType(addr_width)),
|
||||
UIntLiteral(target.index.value, addr_width)))
|
||||
# en
|
||||
stmts.append(Connect(
|
||||
SubField(SubField(Reference(target.mem.name, UIntType(IntWidth(1))),target.name, UIntType(IntWidth(1))), 'en', UIntType(IntWidth(1))),
|
||||
UIntLiteral(1, IntWidth(1))))
|
||||
# clk
|
||||
stmts.append(Connect(
|
||||
SubField(SubField(Reference(target.mem.name, ClockType()),target.name, ClockType()), 'clk', ClockType()),
|
||||
target.clk))
|
||||
# mask
|
||||
if target.rw is False:
|
||||
stmts.append(Connect(
|
||||
SubField(SubField(Reference(target.mem.name, UIntType(IntWidth(1))),target.name, UIntType(IntWidth(1))), 'mask', UIntType(IntWidth(1))),
|
||||
UIntLiteral(1, IntWidth(1))))
|
||||
|
||||
def expand_memory_e(s: Statement, ports: Dict[str, Statement]) -> Statement:
|
||||
loc, expr = s.loc, s.expr
|
||||
if isinstance(loc, Reference) and loc.name in ports:
|
||||
loc = SubField(SubField(Reference(ports[loc.name].mem.name, loc.typ), loc.name, loc.typ), 'data', loc.typ)
|
||||
elif isinstance(expr, Reference) and expr.name in ports:
|
||||
expr = SubField(SubField(Reference(ports[expr.name].mem.name, expr.typ), expr.name, expr.typ), 'data', expr.typ)
|
||||
return Connect(loc, expr, s.info, s.blocking, s.bidirection, s.mem)
|
||||
|
||||
def expand_memory_s(stmts: List[Statement]) -> List[Statement]:
|
||||
new_stmts: List[Statement] = []
|
||||
writes: Dict[str, List[Statement]] = {}
|
||||
reads: Dict[str, List[Statement]] = {}
|
||||
ports: Dict[str, List[Statement]] = {}
|
||||
get_mem_ports(stmts, writes, reads)
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, DefMemory):
|
||||
new_stmts.append(WDefMemory(
|
||||
stmt.name,
|
||||
stmt.memType,
|
||||
stmt.memType.typ,
|
||||
stmt.memType.size,
|
||||
DEFAULT_READ_LATENCY,
|
||||
DEFAULT_WRITE_LATENCY,
|
||||
reads[stmt.name],
|
||||
writes[stmt.name]))
|
||||
elif isinstance(stmt, DefMemPort):
|
||||
expand_mem_port(new_stmts, stmt)
|
||||
ports[stmt.name] = stmt
|
||||
elif isinstance(stmt, Connect):
|
||||
new_stmts.append(expand_memory_e(stmt, ports))
|
||||
else:
|
||||
new_stmts.append(stmt)
|
||||
return new_stmts
|
||||
|
||||
def expand_memory_m(m: DefModule) -> DefModule:
|
||||
return Module(
|
||||
m.name,
|
||||
m.ports,
|
||||
Block(expand_memory_s(m.body.stmts)),
|
||||
m.typ,
|
||||
m.info
|
||||
)
|
||||
|
||||
new_modules = []
|
||||
for m in c.modules:
|
||||
if isinstance(m, Module):
|
||||
new_modules.append(expand_memory_m(m))
|
||||
else:
|
||||
new_modules.append(m)
|
||||
|
||||
return Circuit(new_modules, c.main, c.info)
|
|
@ -0,0 +1,122 @@
|
|||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from pyhcl.passes._pass import Pass
|
||||
|
||||
@dataclass
|
||||
class ExpandSequential(Pass):
|
||||
|
||||
def run(self, c: Circuit):
|
||||
modules: List[Module] = []
|
||||
blocks: List[Statement] = []
|
||||
block_map: Dict[str, List[Statement]] = {}
|
||||
clock_map: Dict[str, Expression] = {}
|
||||
|
||||
def get_ref_name(e: Expression) -> str:
|
||||
if isinstance(e, SubAccess):
|
||||
return get_ref_name(e.expr)
|
||||
elif isinstance(e, SubField):
|
||||
return get_ref_name(e.expr)
|
||||
elif isinstance(e, SubIndex):
|
||||
return get_ref_name(e.expr)
|
||||
elif isinstance(e, Reference):
|
||||
return e.name
|
||||
|
||||
def expand_sequential_s(s: Statement, stmts: List[Statement], reg_map: Dict[str, DefRegister]):
|
||||
if isinstance(s, Conditionally):
|
||||
conseq_seq_map, conseq_com = expand_sequential_s(s.conseq, stmts, reg_map)
|
||||
alt_seq_map, alt_com = expand_sequential_s(s.alt, stmts, reg_map)
|
||||
for k in conseq_seq_map:
|
||||
if k not in alt_seq_map:
|
||||
alt_seq_map[k] = EmptyStmt()
|
||||
com = Conditionally(s.pred, conseq_com, alt_com, s.info)
|
||||
if isinstance(conseq_com, EmptyStmt):
|
||||
if isinstance(alt_com, EmptyStmt):
|
||||
com = EmptyStmt()
|
||||
else:
|
||||
com = Conditionally(DoPrim(Not(), [s.pred], [], s.pred.typ))
|
||||
return {k: Conditionally(s.pred, v, alt_seq_map[k], s.info) for k, v in conseq_seq_map.items()}, com
|
||||
elif isinstance(s, Block):
|
||||
com_stmts: List[Statement] = []
|
||||
seq_stmts_map: Dict[str, List[Statement]] = {}
|
||||
for sx in s.stmts:
|
||||
if isinstance(sx, Connect) and get_ref_name(sx.loc) in reg_map:
|
||||
reg = reg_map[get_ref_name(sx.loc)]
|
||||
if reg.clock.verilog_serialize() not in clock_map:
|
||||
clock_map[reg.clock.verilog_serialize()] = reg.clock
|
||||
if reg.clock.verilog_serialize() not in seq_stmts_map:
|
||||
seq_stmts_map[reg.clock.verilog_serialize()] = []
|
||||
seq_stmts_map[reg.clock.verilog_serialize()].append(Connect(sx.loc, sx.expr, sx.info, False, sx.bidirection, sx.mem))
|
||||
elif isinstance(sx, Connect) and get_ref_name(sx.loc) not in reg_map:
|
||||
com_stmts.append(sx)
|
||||
elif isinstance(sx, Conditionally):
|
||||
seq_when_map, com_when = expand_sequential_s(sx, stmts, reg_map)
|
||||
for k in seq_when_map:
|
||||
if k not in seq_stmts_map:
|
||||
seq_stmts_map[k] = []
|
||||
seq_stmts_map[k].append(seq_when_map[k])
|
||||
com_stmts.append(com_when)
|
||||
else:
|
||||
stmts.append(sx)
|
||||
return {k: Block(v) if len(v) > 0 else EmptyStmt() for k, v in seq_stmts_map.items()}, \
|
||||
Block(com_stmts) if len(com_stmts) > 0 else EmptyStmt()
|
||||
else:
|
||||
return {}, s
|
||||
|
||||
def expand_sequential(stmts: List[Statement]) -> List[Statement]:
|
||||
reg_map: Dict[str, DefRegister] = {sx.name: sx for sx in stmts if isinstance(sx, DefRegister)}
|
||||
mem_map: Dict[str, DefMemPort] = {sx.name: sx for sx in stmts if isinstance(sx, DefMemPort)}
|
||||
new_stmts: List[Statement] = []
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, Conditionally):
|
||||
seq_map, com = expand_sequential_s(stmt, new_stmts, reg_map)
|
||||
if not isinstance(com, EmptyStmt):
|
||||
new_stmts.append(com)
|
||||
for k in seq_map:
|
||||
if k not in block_map:
|
||||
block_map[k] = []
|
||||
if not isinstance(seq_map[k], EmptyStmt):
|
||||
block_map[k].append(seq_map[k])
|
||||
else:
|
||||
new_stmts.append(stmt)
|
||||
reset_map: Dict[str, List[Statement]] = {}
|
||||
reset_sign_map: Dict[str, List[Expression]] = {}
|
||||
for reg_name in reg_map:
|
||||
reg: DefRegister = reg_map[reg_name]
|
||||
if reg.init is not None:
|
||||
if reg.clock.verilog_serialize() not in reset_map:
|
||||
reset_map[reg.clock.verilog_serialize()] = []
|
||||
reset_map[reg.clock.verilog_serialize()].append(Connect(Reference(reg.name, reg.typ), reg.init, reg.info, False))
|
||||
if reg.clock.verilog_serialize not in reset_sign_map:
|
||||
reset_sign_map[reg.clock.verilog_serialize()] = []
|
||||
reset_sign_map[reg.clock.verilog_serialize()].append(reg.reset)
|
||||
for k in reset_map:
|
||||
if len(reset_map[k]) > 0:
|
||||
for rs in reset_sign_map[k]:
|
||||
block_map[k].append(Conditionally(rs, Block(reset_map[k]), EmptyStmt()))
|
||||
for k in mem_map:
|
||||
mem: Reference = mem_map[k].mem
|
||||
sig = DoPrim(And(), [Reference(f"{mem.name}_{k}_en", UIntType(IntWidth(1))),
|
||||
Reference(f"{mem.name}_{k}_mask", UIntType(IntWidth(1)))], [], UIntType(IntWidth(1)))
|
||||
con = Connect(SubAccess(mem, Reference(f"{mem.name}_{k}_addr", mem_map[k].index.typ), mem.typ),
|
||||
Reference(f"{mem.name}_{k}_data", mem.typ),mem_map[k].info, False)
|
||||
if mem_map[k].rw is False:
|
||||
if mem_map[k].clk.verilog_serialize() not in block_map:
|
||||
block_map[mem_map[k].clk.verilog_serialize()] = []
|
||||
if mem_map[k].clk.verilog_serialize() not in clock_map:
|
||||
clock_map[mem_map[k].clk.verilog_serialize()] = mem_map[k].clk
|
||||
block_map[mem_map[k].clk.verilog_serialize()].append(Conditionally(sig, Block([con]), EmptyStmt()))
|
||||
for k in block_map:
|
||||
new_stmts.append(AlwaysBlock(block_map[k], clock_map[k]))
|
||||
return new_stmts
|
||||
|
||||
def expand_sequential_m(m: DefModule) -> DefModule:
|
||||
if isinstance(m, Module):
|
||||
return Module(m.name, m.ports, Block(expand_sequential(m.body.stmts)), m.typ)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(expand_sequential_m(m))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,70 @@
|
|||
from typing import List
|
||||
|
||||
from numpy import isin
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
from pyhcl.passes.utils import AutoName
|
||||
|
||||
@dataclass
|
||||
class ExpandWhens(Pass):
|
||||
def run(self, c: Circuit) -> Circuit:
|
||||
modules: List[DefModule] = []
|
||||
|
||||
def auto_gen_name():
|
||||
return AutoName.auto_gen_name()
|
||||
|
||||
def last_name():
|
||||
return AutoName.last_name()
|
||||
|
||||
def expand_whens(s: Statement, stmts: List[Statement], refs: Dict[str, List[Statement]], pred: Expression = None):
|
||||
if isinstance(s, Conditionally):
|
||||
expand_whens(s.conseq, stmts, refs, s.pred)
|
||||
expand_whens(s.alt, stmts, refs, DoPrim(Not(), [s.pred], [], s.pred.typ))
|
||||
elif isinstance(s, Block):
|
||||
for sx in s.stmts:
|
||||
expand_whens(sx, stmts, refs, pred)
|
||||
elif isinstance(s, EmptyStmt):
|
||||
...
|
||||
elif isinstance(s, Connect):
|
||||
if s.loc.serialize() not in refs:
|
||||
refs[s.loc.serialize()] = []
|
||||
refs[s.loc.serialize()].append(Conditionally(pred, Block([s]), EmptyStmt()))
|
||||
else:
|
||||
stmts.append(s)
|
||||
|
||||
def expand_whens_s(ss: List[Statement]) -> List[Statement]:
|
||||
stmts: List[Statement] = []
|
||||
refs: Dict[str, List[Statement]] = {}
|
||||
for sx in ss:
|
||||
if isinstance(sx, Conditionally):
|
||||
expand_whens(sx, stmts, refs)
|
||||
else:
|
||||
stmts.append(sx)
|
||||
for sx in refs.values():
|
||||
if len(sx) <= 1:
|
||||
sxx = sx.pop()
|
||||
con = sxx.conseq.stmts.pop()
|
||||
stmts.append(Connect(con.loc, ValidIf(sxx.pred, con.expr, con.expr.typ)))
|
||||
else:
|
||||
sxx = sx.pop()
|
||||
con = sxx.conseq.stmts.pop()
|
||||
stmts.append(DefNode(auto_gen_name(), ValidIf(sxx.pred, con.expr, con.expr.typ)))
|
||||
while len(sx) > 1:
|
||||
sxx = sx.pop()
|
||||
con = sxx.conseq.stmts.pop()
|
||||
stmts.append(DefNode(auto_gen_name(), Mux(sxx.pred, con.expr, Reference(AutoName.last_name(), con.expr.typ), con.expr.typ)))
|
||||
sxx = sx.pop()
|
||||
con = sxx.conseq.stmts.pop()
|
||||
stmts.append(Connect(con.loc, Mux(sxx.pred, con.expr, Reference(last_name(), con.expr.typ), con.expr.typ)))
|
||||
return stmts
|
||||
|
||||
def expand_whens_m(m: DefModule) -> DefModule:
|
||||
if isinstance(m, Module):
|
||||
return Module(m.name, m.ports, Block(expand_whens_s(m.body.stmts)), m.typ, m.info)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(expand_whens_m(m))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,33 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Dict
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
|
||||
@dataclass
|
||||
class HandleInstance(Pass):
|
||||
def run(self, c: Circuit) -> Circuit:
|
||||
modules: List[DefModule] = []
|
||||
refs: Dict[str, List[Port]] = {m.name: m.ports for m in c.modules}
|
||||
|
||||
def handle_instance_s(s: Statement):
|
||||
if isinstance(s, Conditionally):
|
||||
return Conditionally(s.pred, handle_instance_s(s.conseq), handle_instance_s(s.alt), s.info)
|
||||
elif isinstance(s, Block):
|
||||
return Block([handle_instance_s(sx) for sx in s.stmts])
|
||||
elif isinstance(s, DefInstance):
|
||||
if s.module in refs:
|
||||
return DefInstance(s.name, s.module, refs[s.module], s.info)
|
||||
else:
|
||||
return s
|
||||
else:
|
||||
return s
|
||||
|
||||
def handle_instance_m(m: DefModule):
|
||||
if isinstance(m, Module):
|
||||
return Module(m.name, m.ports, handle_instance_s(m.body), m.typ, m.info)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(handle_instance_m(m))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,57 @@
|
|||
from typing import List, Dict
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
|
||||
@dataclass
|
||||
class Optimize(Pass):
|
||||
def run(self, c: Circuit):
|
||||
def get_name(e: Expression) -> str:
|
||||
if isinstance(e, (SubField, SubIndex, SubAccess)):
|
||||
return get_name(e.expr)
|
||||
else:
|
||||
return e.name
|
||||
|
||||
def optimize_s(stmts: List[Statement]) -> List[Statement]:
|
||||
defwires: Dict[str, Statement] = {}
|
||||
connects: Dict[str, Statement] = {}
|
||||
nodes: Dict[str, Statement] = {}
|
||||
new_stmts: List[Statement] = []
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, DefWire):
|
||||
defwires[stmt.name] = stmt
|
||||
if isinstance(stmt, Connect):
|
||||
connects[get_name(stmt.loc)] = stmt
|
||||
|
||||
|
||||
for defwire in defwires.keys():
|
||||
if defwire in connects:
|
||||
nodes[defwire] = DefNode(defwire, connects[defwire].expr)
|
||||
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, DefWire) and stmt.name in nodes:
|
||||
new_stmts.append(nodes[stmt.name])
|
||||
elif isinstance(stmt, Connect) and get_name(stmt.loc) in nodes:
|
||||
...
|
||||
else:
|
||||
new_stmts.append(stmt)
|
||||
|
||||
return new_stmts
|
||||
|
||||
def optimize_m(m: DefModule) -> DefModule:
|
||||
return Module(
|
||||
m.name,
|
||||
m.ports,
|
||||
Block(optimize_s(m.body.stmts)),
|
||||
m.typ,
|
||||
m.info
|
||||
)
|
||||
|
||||
new_modules = []
|
||||
for m in c.modules:
|
||||
if isinstance(m, Module):
|
||||
new_modules.append(optimize_m(m))
|
||||
else:
|
||||
new_modules.append(m)
|
||||
|
||||
return Circuit(new_modules, c.main, c.info)
|
|
@ -0,0 +1,66 @@
|
|||
from pyhcl.ir.low_ir import *
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
from pyhcl.passes._pass import Pass
|
||||
|
||||
@dataclass
|
||||
class RemoveAccess(Pass):
|
||||
def run(self, c: Circuit) -> Circuit:
|
||||
modules: List[Module] = []
|
||||
|
||||
def remove_access(e: Expression, type: Type = None, name: str = None) -> Expression:
|
||||
if isinstance(e, Reference):
|
||||
return Reference(f"{e.name}_{name}", type)
|
||||
elif isinstance(e, SubIndex):
|
||||
if name is None:
|
||||
return remove_access(e.expr, type, f"{e.value}")
|
||||
else:
|
||||
return remove_access(e.expr, type, f"{e.value}_{name}")
|
||||
elif isinstance(e, SubField):
|
||||
if name is None:
|
||||
return remove_access(e.expr, type, f"{e.name}")
|
||||
else:
|
||||
return remove_access(e.expr, type, f"{e.name}_{name}")
|
||||
else:
|
||||
return e
|
||||
|
||||
def remove_access_e(e: Expression) -> Expression:
|
||||
if isinstance(e, (SubIndex, SubField)):
|
||||
return remove_access(e, e.typ)
|
||||
elif isinstance(e, ValidIf):
|
||||
return ValidIf(remove_access_e(e.cond), remove_access_e(e.value), e.typ)
|
||||
elif isinstance(e, Mux):
|
||||
return Mux(remove_access_e(e.cond), remove_access_e(e.tval), remove_access_e(e.fval), e.typ)
|
||||
elif isinstance(e, DoPrim):
|
||||
return DoPrim(e.op, [remove_access_e(arg) for arg in e.args], e.consts, e.typ)
|
||||
else:
|
||||
return e
|
||||
|
||||
def remove_access_s(s: Statement) -> Statement:
|
||||
if isinstance(s, Block):
|
||||
stmts: List[Statement] = []
|
||||
for sx in s.stmts:
|
||||
stmts.append(remove_access_s(sx))
|
||||
return Block(stmts)
|
||||
elif isinstance(s, Conditionally):
|
||||
return Conditionally(remove_access_e(s.pred), remove_access_s(s.conseq), remove_access_s(s.alt), s.info)
|
||||
elif isinstance(s, DefRegister):
|
||||
return DefRegister(s.name, s.typ, remove_access_e(s.clock), remove_access_e(s.reset), remove_access_e(s.init), s.info)
|
||||
elif isinstance(s, DefNode):
|
||||
return DefNode(s.name, remove_access_e(s.value), s.info)
|
||||
elif isinstance(s, DefMemPort):
|
||||
return DefMemPort(s.name, s.mem, remove_access_e(s.index), remove_access_e(s.clk), s.rw, s.info)
|
||||
elif isinstance(s, Connect):
|
||||
return Connect(remove_access_e(s.loc), remove_access_e(s.expr), s.info)
|
||||
else:
|
||||
return s
|
||||
|
||||
def remove_access_m(m: DefModule) -> DefModule:
|
||||
if isinstance(m, Module):
|
||||
return Module(m.name, m.ports, remove_access_s(m.body), m.typ, m.info)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(remove_access_m(m))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,140 @@
|
|||
from sqlite3 import connect
|
||||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
from pyhcl.passes.utils import get_binary_width, AutoName
|
||||
|
||||
@dataclass
|
||||
class ReplaceSubaccess(Pass):
|
||||
def run(self, c: Circuit) -> Circuit:
|
||||
modules: List[DefModule] = []
|
||||
|
||||
def has_access(e: Expression) -> bool:
|
||||
if isinstance(e, SubAccess):
|
||||
return True
|
||||
elif isinstance(e, (SubField, SubIndex)):
|
||||
return has_access(e.expr)
|
||||
else:
|
||||
return False
|
||||
|
||||
def get_ref_name(e: Expression) -> str:
|
||||
if isinstance(e, SubAccess):
|
||||
return get_ref_name(e.expr)
|
||||
elif isinstance(e, SubField):
|
||||
return get_ref_name(e.expr)
|
||||
elif isinstance(e, SubIndex):
|
||||
return get_ref_name(e.expr)
|
||||
elif isinstance(e, Reference):
|
||||
return e.name
|
||||
|
||||
def auto_gen_name():
|
||||
return AutoName.auto_gen_name()
|
||||
|
||||
def last_name():
|
||||
return AutoName.last_name()
|
||||
|
||||
def replace_subaccess(e: Expression):
|
||||
cons: List[Expression] = []
|
||||
exps: List[Expression] = []
|
||||
if isinstance(e, SubAccess):
|
||||
xcons, xexps = replace_subaccess(e.expr)
|
||||
if len(cons) == 0 and len(xexps) == 0:
|
||||
if isinstance(e.expr.typ, VectorType):
|
||||
for i in range(e.expr.typ.size):
|
||||
cons.append(DoPrim(Eq(), [e.index, UIntLiteral(i, IntWidth(get_binary_width(e.expr.typ.size)))], [], UIntType(IntWidth(1))))
|
||||
exps.append(SubIndex("", e.expr, i, e.typ))
|
||||
else:
|
||||
exps.append(e)
|
||||
else:
|
||||
if isinstance(e.expr.typ, VectorType):
|
||||
for i in range(e.expr.typ.size):
|
||||
for xcon in xcons:
|
||||
cons.append(DoPrim(And(), [xcon, DoPrim(Eq(), [e.index, UIntLiteral(i, IntWidth(get_binary_width(e.expr.typ.size)))],
|
||||
[], UIntType(IntWidth(1)))], [], UIntType(IntWidth(1))))
|
||||
for xexp in xexps:
|
||||
exps.append(SubIndex("", xexp, i, e.typ))
|
||||
else:
|
||||
cons, exps = xcons, xexps
|
||||
elif isinstance(e, SubField):
|
||||
xcons, xexps = replace_subaccess(e.expr)
|
||||
cons = xcons
|
||||
for xexp in xexps:
|
||||
exps.append(SubField(xexp, e.name, e.typ))
|
||||
elif isinstance(e, SubIndex):
|
||||
xcons, xexps = replace_subaccess(e.expr)
|
||||
cons = xcons
|
||||
for xexp in xexps:
|
||||
exps.append(SubIndex("", xexp, e.value, e.typ))
|
||||
|
||||
return cons, exps
|
||||
|
||||
def replace_subaccess_e(e: Expression, stmts: List[Statement], is_sink: bool = False, source: Expression = None) -> Expression:
|
||||
if isinstance(e, ValidIf):
|
||||
return ValidIf(replace_subaccess_e(e.cond, stmts), replace_subaccess_e(e.value, stmts), e.typ)
|
||||
elif isinstance(e, Mux):
|
||||
return Mux(replace_subaccess_e(e.cond, stmts), replace_subaccess_e(e.tval, stmts), replace_subaccess_e(e.fval, stmts), e.typ)
|
||||
elif isinstance(e, DoPrim):
|
||||
return DoPrim(e.op, [replace_subaccess_e(arg, stmts) for arg in e.args], e.consts, e.typ)
|
||||
elif isinstance(e, (SubAccess, SubField, SubIndex)) and has_access(e):
|
||||
if is_sink:
|
||||
cons, exps = replace_subaccess(e)
|
||||
gen_nodes: Dict[str, DefNode] = {}
|
||||
for i in range(len(cons)):
|
||||
stmts.append(Connect(exps[i], Mux(cons[i], source, exps[i], e.typ)))
|
||||
return
|
||||
else:
|
||||
cons, exps = replace_subaccess(e)
|
||||
gen_nodes: Dict[str, DefNode] = {}
|
||||
for i in range(len(cons)):
|
||||
if i == 0:
|
||||
name = auto_gen_name()
|
||||
gen_node = DefNode(name, ValidIf(cons[i], exps[i], e.typ))
|
||||
gen_nodes[name] = gen_node
|
||||
else:
|
||||
last_node = gen_nodes[last_name()]
|
||||
name = auto_gen_name()
|
||||
gen_node = DefNode(name, Mux(cons[i], exps[i], last_node.value.value if i == 1 else Reference(last_node.name, last_node.value.typ), e.typ))
|
||||
stmts.append(gen_node)
|
||||
gen_nodes[name] = gen_node
|
||||
return Reference(gen_nodes[last_name()].name, e.typ)
|
||||
else:
|
||||
return e
|
||||
|
||||
def replace_subaccess_s(s: Statement) -> Statement:
|
||||
if isinstance(s, Block):
|
||||
stmts: List[Statement] = []
|
||||
for stmt in s.stmts:
|
||||
if isinstance(stmt, Connect):
|
||||
expr = replace_subaccess_e(stmt.expr, stmts)
|
||||
loc = replace_subaccess_e(stmt.loc, stmts, True, expr)
|
||||
if loc is not None:
|
||||
stmts.append(Connect(loc, expr, stmt.info, stmt.blocking, stmt.bidirection, stmt.mem))
|
||||
else:
|
||||
...
|
||||
# stmts.append(stmt)
|
||||
elif isinstance(stmt, DefNode):
|
||||
stmts.append(DefNode(stmt.name, replace_subaccess_e(stmt.value, stmts), stmt.info))
|
||||
elif isinstance(stmt, DefRegister):
|
||||
stmts.append(DefRegister(stmt.name, stmt.typ, stmt.clock, stmt.reset, replace_subaccess_e(stmt.init, stmts), stmt.info))
|
||||
elif isinstance(stmt, Conditionally):
|
||||
stmts.append(replace_subaccess_s(stmt))
|
||||
else:
|
||||
stmts.append(stmt)
|
||||
return Block(stmts)
|
||||
elif isinstance(s, EmptyStmt):
|
||||
return EmptyStmt()
|
||||
elif isinstance(s, Conditionally):
|
||||
return Conditionally(s.pred, replace_subaccess_s(s.conseq), replace_subaccess_s(s.alt), s.info)
|
||||
else:
|
||||
return s
|
||||
|
||||
def replace_subaccess_m(m: DefModule) -> DefModule:
|
||||
if isinstance(m, Module):
|
||||
return Module(m.name, m.ports, replace_subaccess_s(m.body), m.typ, m.info)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(replace_subaccess_m(m))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,252 @@
|
|||
import math
|
||||
|
||||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.passes.wir import DuplexFlow, Flow, SinkFlow, SourceFlow, UnknownFlow
|
||||
from pyhcl.passes._pass import PassException
|
||||
|
||||
# CheckForm utils
|
||||
class ModuleGraph:
|
||||
nodes: Dict[str, set] = {}
|
||||
|
||||
def add(self, parent: str, child: str) -> List[str]:
|
||||
if parent not in self.nodes.keys():
|
||||
self.nodes[parent] = set()
|
||||
self.nodes[parent].add(child)
|
||||
return self.path_exists(child, parent, [child, parent])
|
||||
|
||||
def path_exists(self, child: str, parent: str, path: List[str] = None) -> List[str]:
|
||||
cx = self.nodes[child] if child in self.nodes.keys() else None
|
||||
if cx is not None:
|
||||
if parent in cx:
|
||||
return [parent] + path
|
||||
else:
|
||||
for cxx in cx:
|
||||
new_path = self.path_exists(cxx, parent, [cxx] + path)
|
||||
if len(new_path) > 0:
|
||||
return new_path
|
||||
return None
|
||||
|
||||
|
||||
def is_max(w1: Width, w2: Width):
|
||||
return IntWidth(max(w1.width, w2.width))
|
||||
|
||||
def to_flow(d: Direction) -> Flow:
|
||||
if isinstance(d, Input):
|
||||
return SourceFlow()
|
||||
if isinstance(d, Output):
|
||||
return SinkFlow()
|
||||
|
||||
def flow(e: Expression) -> Flow:
|
||||
if isinstance(e, DoPrim):
|
||||
return SourceFlow()
|
||||
elif isinstance(e, UIntLiteral):
|
||||
return SourceFlow()
|
||||
elif isinstance(e, SIntLiteral):
|
||||
return SourceFlow()
|
||||
elif isinstance(e, Mux):
|
||||
return SourceFlow()
|
||||
elif isinstance(e, ValidIf):
|
||||
return SourceFlow()
|
||||
else:
|
||||
raise PassException(f'flow: shouldn\'t be here - {e}')
|
||||
|
||||
def mux_type_and_widths(e1, e2) -> Type:
|
||||
return mux_type_and_width(e1.typ, e2.typ)
|
||||
|
||||
def mux_type_and_width(t1: Type, t2: Type) -> Type:
|
||||
if isinstance(t1, ClockType) and isinstance(t2, ClockType):
|
||||
return ClockType()
|
||||
elif isinstance(t1, AsyncResetType) and isinstance(t2, AsyncResetType):
|
||||
return AsyncResetType()
|
||||
elif isinstance(t1, UIntType) and isinstance(t2, UIntType):
|
||||
return UIntType(is_max(t1.width, t2.width))
|
||||
elif isinstance(t1, VectorType) and isinstance(t2, VectorType):
|
||||
return VectorType(mux_type_and_width(t1.typ, t2.typ), t1.size)
|
||||
elif isinstance(t1, BundleType) and isinstance(t2, BundleType):
|
||||
return BundleType(map(lambda f1, f2: Field(f1.name, f1.flip, mux_type_and_width(f1.typ, f2.typ)), list(zip(t1.fields, t2.fields))))
|
||||
else:
|
||||
return UnknownType
|
||||
|
||||
def create_exps(e: Expression) -> List[Expression]:
|
||||
if isinstance(e, Mux):
|
||||
e1s = create_exps(e.tval)
|
||||
e2s = create_exps(e.fval)
|
||||
return list(map(lambda e1, e2: Mux(e.cond, e1, e2, mux_type_and_widths(e1, e2)), list(zip(e1s, e2s))))
|
||||
elif isinstance(e, ValidIf):
|
||||
return list(map(lambda e1: ValidIf(e.cond, e1, e1.typ), create_exps(e.value)))
|
||||
else:
|
||||
if isinstance(e.typ, GroundType):
|
||||
return [e]
|
||||
elif isinstance(e.typ, BundleType):
|
||||
exprs = []
|
||||
for f in e.typ.fields:
|
||||
exprs = exprs + create_exps(Reference(f'{e.name}_{f.name}', f.typ))
|
||||
return exprs
|
||||
elif isinstance(e.typ, VectorType):
|
||||
exprs = []
|
||||
for i in range(e.typ.size):
|
||||
exprs = exprs + create_exps(Reference(f'{e.name}_{i}', f.typ))
|
||||
return exprs
|
||||
|
||||
def get_info(s: Statement) -> Info:
|
||||
if hasattr(s, 'info'):
|
||||
return s.info
|
||||
else:
|
||||
return NoInfo
|
||||
|
||||
def has_flip(typ: Type) -> bool:
|
||||
if isinstance(typ, BundleType):
|
||||
for f in typ.fields:
|
||||
if isinstance(f, Flip):
|
||||
return True
|
||||
else:
|
||||
return has_flip(f.typ)
|
||||
elif isinstance(typ, VectorType):
|
||||
return has_flip(typ.typ)
|
||||
else:
|
||||
return False
|
||||
|
||||
# InterTypes utils
|
||||
|
||||
def to_flip(d: Direction):
|
||||
if isinstance(d, Output):
|
||||
return Default()
|
||||
if isinstance(d, Input):
|
||||
return Flip()
|
||||
|
||||
def module_type(m: DefModule) -> BundleType:
|
||||
fields = [Field(p.name, to_flip(p.direction), p.typ) for p in m.ports]
|
||||
return BundleType(fields)
|
||||
|
||||
def field_type(v: Type, s: str) -> Type:
|
||||
if isinstance(v, BundleType):
|
||||
def match_type(f: Field) -> Type:
|
||||
if f is None:
|
||||
return UnknownType
|
||||
return f.typ
|
||||
|
||||
for f in v.fields:
|
||||
if f.name == s:
|
||||
return match_type(f)
|
||||
return UnknownType()
|
||||
|
||||
def sub_type(v: Type) -> Type:
|
||||
if isinstance(v, VectorType):
|
||||
return v.typ
|
||||
return UnknownType()
|
||||
|
||||
def mux_type(e1: Expression, e2: Expression) -> Type:
|
||||
return mux_types(e1.typ, e2.typ)
|
||||
|
||||
def mux_types(t1: Type, t2: Type) -> Type:
|
||||
if isinstance(t1, ClockType) and isinstance(t2, ClockType):
|
||||
return ClockType()
|
||||
elif isinstance(t1, AsyncResetType) and isinstance(t2, AsyncResetType):
|
||||
return AsyncResetType()
|
||||
elif isinstance(t1, UIntType) and isinstance(t2, UIntType):
|
||||
return t1
|
||||
elif isinstance(t1, SIntType) and isinstance(t2, SIntType):
|
||||
return t2
|
||||
elif isinstance(t1, VectorType) and isinstance(t2, VectorType):
|
||||
return VectorType(mux_types(t1.typ, t2.typ), t1.size)
|
||||
elif isinstance(t1, BundleType) and isinstance(t2, BundleType):
|
||||
return BundleType(list(map(lambda f1, f2: Field(f1.name, f1.flip, mux_types(f1.typ, f2.typ)), list(zip(t1.fields, t2.fields)))))
|
||||
else:
|
||||
return UnknownType()
|
||||
|
||||
def get_or_else(cond, a, b):
|
||||
return a if cond else b
|
||||
|
||||
# CheckTypes utils
|
||||
def swp_flow(f: Flow) -> Flow:
|
||||
if isinstance(f, UnknownFlow):
|
||||
return UnknownFlow()
|
||||
elif isinstance(f, SourceFlow):
|
||||
return SinkFlow()
|
||||
elif isinstance(f, SinkFlow):
|
||||
return SourceFlow()
|
||||
elif isinstance(f, DuplexFlow):
|
||||
return DuplexFlow()
|
||||
else:
|
||||
return Flow()
|
||||
|
||||
def swp_direction(d: Direction) -> Direction:
|
||||
if isinstance(d, Input):
|
||||
return Output()
|
||||
elif isinstance(d, Output):
|
||||
return Input()
|
||||
else:
|
||||
return Direction()
|
||||
|
||||
def swp_orientation(o: Orientation) -> Orientation:
|
||||
if isinstance(o, Default):
|
||||
return Default()
|
||||
elif isinstance(o, Flip):
|
||||
return Flip()
|
||||
else:
|
||||
return Orientation()
|
||||
|
||||
def times_d_flip(d: Direction, flip: Orientation) -> Direction:
|
||||
if isinstance(flip, Default):
|
||||
return d
|
||||
elif isinstance(flip, Flip):
|
||||
return swp_direction(d)
|
||||
|
||||
def times_g_d(g: Flow, d: Direction) -> Direction:
|
||||
return times_d_g(d, g)
|
||||
|
||||
def times_d_g(d: Direction, g: Flow) -> Direction:
|
||||
if isinstance(g, SinkFlow):
|
||||
return d
|
||||
elif isinstance(g, SourceFlow):
|
||||
return swp_flow(d)
|
||||
|
||||
def times_g_flip(g: Flow, flip: Orientation) -> Flow:
|
||||
return times_flip_g(flip, g)
|
||||
|
||||
def times_flip_g(flip: Orientation, g: Flow) -> Flow:
|
||||
if isinstance(flip, Default):
|
||||
return g
|
||||
elif isinstance(flip, Flip):
|
||||
return swp_flow(g)
|
||||
|
||||
def times_f_f(f1: Orientation, f2: Orientation) -> Orientation:
|
||||
if isinstance(f2, Default):
|
||||
return f1
|
||||
elif isinstance(f2, Flip):
|
||||
return swp_orientation(f1)
|
||||
|
||||
# CheckWidth Utils
|
||||
def get_binary_width(target):
|
||||
width = 1
|
||||
while target / 2 >= 1:
|
||||
width += 1
|
||||
target = math.floor(target / 2)
|
||||
return width
|
||||
|
||||
def get_width(w: Width) -> int:
|
||||
if isinstance(w, UnknownWidth):
|
||||
return 0
|
||||
return w.width
|
||||
|
||||
def has_width(t: Type) -> bool:
|
||||
if hasattr(t, 'width'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
class AutoName:
|
||||
endwith: int = -1
|
||||
names: List[str] = []
|
||||
|
||||
@staticmethod
|
||||
def auto_gen_name():
|
||||
AutoName.endwith += 1
|
||||
gen_name = f"GEN_{AutoName.endwith}"
|
||||
AutoName.names.append(gen_name)
|
||||
return gen_name
|
||||
|
||||
@staticmethod
|
||||
def last_name():
|
||||
return AutoName.names[-1]
|
|
@ -0,0 +1,86 @@
|
|||
from dataclasses import dataclass
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.passes._pass import Pass
|
||||
from typing import List, Dict
|
||||
from pyhcl.passes.utils import AutoName
|
||||
|
||||
@dataclass
|
||||
class VerilogOptimize(Pass):
|
||||
def run(self, c: Circuit):
|
||||
modules: List[DefModule] = []
|
||||
|
||||
def auto_gen_node(s):
|
||||
return isinstance(s, DefNode) and s.name.startswith("_T")
|
||||
|
||||
def get_name(e: Expression) -> str:
|
||||
if isinstance(e, (SubAccess, SubField, SubIndex)):
|
||||
return get_name(e.expr)
|
||||
elif isinstance(e, Reference):
|
||||
return e.verilog_serialize()
|
||||
|
||||
def verilog_optimize_e(expr: Expression, node_map: Dict[str, Statement], filter_nodes: set, stmts: List[Statement]) -> Expression:
|
||||
if isinstance(expr, (UIntLiteral, SIntLiteral)):
|
||||
return expr
|
||||
elif isinstance(expr, Reference):
|
||||
en = get_name(expr)
|
||||
if en in node_map:
|
||||
filter_nodes.add(en)
|
||||
return verilog_optimize_e(node_map[en].value, node_map, filter_nodes, stmts)
|
||||
else:
|
||||
return expr
|
||||
elif isinstance(expr, (SubField, SubIndex, SubAccess)):
|
||||
return expr
|
||||
elif isinstance(expr, Mux):
|
||||
return Mux(
|
||||
verilog_optimize_e(expr.cond, node_map, filter_nodes, stmts),
|
||||
verilog_optimize_e(expr.tval, node_map, filter_nodes, stmts),
|
||||
verilog_optimize_e(expr.fval, node_map, filter_nodes, stmts), expr.typ)
|
||||
elif isinstance(expr, ValidIf):
|
||||
return ValidIf(
|
||||
verilog_optimize_e(expr.cond, node_map, filter_nodes, stmts),
|
||||
verilog_optimize_e(expr.value, node_map, filter_nodes, stmts), expr.typ)
|
||||
elif isinstance(expr, DoPrim):
|
||||
args = list(map(lambda arg: verilog_optimize_e(arg, node_map, filter_nodes, stmts), expr.args))
|
||||
if isinstance(expr.op, Bits) and isinstance(args[0], DoPrim):
|
||||
name = AutoName.auto_gen_name()
|
||||
stmts.append(DefNode(name, args[0]))
|
||||
args = [Reference(name, args[0].typ)]
|
||||
return DoPrim(expr.op, args, expr.consts, expr.typ)
|
||||
else:
|
||||
return expr
|
||||
|
||||
def verilog_optimize_s(stmt: Statement, node_map: Dict[str, Statement], filter_nodes: set, stmts: List[Statement] = None) -> Statement:
|
||||
if isinstance(stmt, Block):
|
||||
node_map = {**node_map ,**{sx.name: sx for sx in stmt.stmts if auto_gen_node(sx)}}
|
||||
cat_stmts = []
|
||||
for sx in stmt.stmts:
|
||||
if isinstance(sx, Connect):
|
||||
cat_stmts.append(Connect(verilog_optimize_e(sx.loc, node_map, filter_nodes, cat_stmts),
|
||||
verilog_optimize_e(sx.expr, node_map, filter_nodes, cat_stmts), sx.info, sx.blocking, sx.bidirection, sx.mem))
|
||||
elif isinstance(sx, DefNode):
|
||||
cat_stmts.append(DefNode(sx.name, verilog_optimize_e(sx.value, node_map, filter_nodes, cat_stmts), sx.info))
|
||||
elif isinstance(sx, Conditionally):
|
||||
cat_stmts.append(Conditionally(verilog_optimize_e(sx.pred, node_map, filter_nodes, cat_stmts), verilog_optimize_s(sx.conseq, node_map, filter_nodes, cat_stmts),
|
||||
verilog_optimize_s(sx.alt, node_map, filter_nodes, cat_stmts), sx.info))
|
||||
else:
|
||||
cat_stmts.append(sx)
|
||||
cat_stmts = [sx for sx in cat_stmts if not (isinstance(sx, DefNode) and sx.name in filter_nodes)]
|
||||
return Block(cat_stmts)
|
||||
elif isinstance(stmt, Conditionally):
|
||||
return Conditionally(verilog_optimize_e(stmt.pred, node_map, filter_nodes, stmts), verilog_optimize_s(stmt.conseq, node_map, filter_nodes, stmts),
|
||||
verilog_optimize_s(stmt.alt, node_map, filter_nodes, stmts), stmt.info)
|
||||
else:
|
||||
return stmt
|
||||
|
||||
|
||||
def verilog_optimize_m(m: DefModule) -> DefModule:
|
||||
node_map: Dict[str, DefNode] = {}
|
||||
filter_nodes: set = set()
|
||||
if isinstance(m, Module):
|
||||
return Module(m.name, m.ports, verilog_optimize_s(m.body, node_map, filter_nodes), m.typ, m.info)
|
||||
else:
|
||||
return m
|
||||
|
||||
for m in c.modules:
|
||||
modules.append(verilog_optimize_m(m))
|
||||
return Circuit(modules, c.main, c.info)
|
|
@ -0,0 +1,81 @@
|
|||
from abc import ABC
|
||||
from pyhcl.ir.low_ir import *
|
||||
|
||||
class Flow(ABC):
|
||||
...
|
||||
|
||||
class SourceFlow(Flow):
|
||||
...
|
||||
|
||||
class SinkFlow(Flow):
|
||||
...
|
||||
|
||||
class DuplexFlow(Flow):
|
||||
...
|
||||
|
||||
class UnknownFlow(Flow):
|
||||
...
|
||||
|
||||
class Kind(ABC):
|
||||
...
|
||||
|
||||
class PortKind(Kind):
|
||||
...
|
||||
|
||||
class VarWidth(Width):
|
||||
name: str
|
||||
|
||||
def serialize(self):
|
||||
return self.name
|
||||
|
||||
class WrappedType(Type):
|
||||
def __init__(self, t: Type):
|
||||
self.t = t
|
||||
|
||||
def __eq__(self, o):
|
||||
if isinstance(o, WrappedType):
|
||||
return WrappedType.compare(self.t, o.t)
|
||||
else:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def compare(sink: Type, source: Type):
|
||||
def legal_reset_type(self, typ: Type) -> bool:
|
||||
if isinstance(typ, UIntType) and isinstance(typ.width, IntWidth):
|
||||
return typ.width.width == 1
|
||||
elif isinstance(typ, AsyncResetType):
|
||||
return True
|
||||
elif isinstance(typ, ResetType):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
if isinstance(sink, UIntType) and isinstance(source, UIntType):
|
||||
return True
|
||||
elif isinstance(sink, SIntType) and isinstance(source, SIntType):
|
||||
return True
|
||||
elif isinstance(sink, ClockType) and isinstance(source, ClockType):
|
||||
return True
|
||||
elif isinstance(sink, AsyncResetType) and isinstance(source, AsyncResetType):
|
||||
return True
|
||||
elif isinstance(sink, ResetType):
|
||||
return legal_reset_type(source)
|
||||
elif isinstance(source, ResetType):
|
||||
return legal_reset_type(sink)
|
||||
elif isinstance(sink, VectorType) and isinstance(source, VectorType):
|
||||
return sink.size == source.size and WrappedType.compare(sink.typ, source.typ)
|
||||
elif isinstance(sink, BundleType) and isinstance(source, BundleType):
|
||||
final = True
|
||||
for f1, f2 in list(zip(sink.fields, source.fields)):
|
||||
f1_final = WrappedType.compare(f2.typ, f1.typ) if isinstance(f1.flip, Flip) else WrappedType.compare(f1.typ, f2.typ)
|
||||
final = f1.flip == f2.flip and f1.name == f2.name and f1_final and final
|
||||
|
||||
return len(sink.fields) == len(source.fields) and final
|
||||
else:
|
||||
return False
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
|
@ -120,15 +120,16 @@ class Simlite(object):
|
|||
with open(f"./simulation/{self.dut_name}-harness.cpp", "w+") as f:
|
||||
f.write(harness_code)
|
||||
|
||||
# 在simulation文件夹下创建dut_name-harness.fir,写入firrtl代码
|
||||
with open(f"./simulation/{self.dut_name}.fir", "w+") as f:
|
||||
f.write(self.low_module.serialize())
|
||||
# 调用FIRRTL工具链
|
||||
# with open(f"./simulation/{self.dut_name}.fir", "w+") as f:
|
||||
# f.write(self.low_module.serialize())
|
||||
|
||||
# 调用firrtl命令,传入firrtl代码,得到verilog代码
|
||||
# firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog
|
||||
# print(f"firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog")
|
||||
os.system(
|
||||
f"firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog")
|
||||
# os.system(
|
||||
# f"firrtl -i ./simulation/{self.dut_name}.fir -o ./simulation/{self.dut_name}.v -X verilog")
|
||||
|
||||
# # 调用PyHCL编译链
|
||||
with open(f"./simulation/{self.dut_name}.v", "w+") as f:
|
||||
f.write(Verilog(self.low_module).emit())
|
||||
|
||||
vfn = "{}.v".format(self.dut_name) # {dut_name}.v
|
||||
hfn = "{}-harness.cpp".format(self.dut_name) # {dut_name}-harness.cpp
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
from abc import ABC, abstractclassmethod
|
||||
from typing import List
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.tester.symbol_table import SymbolTable
|
||||
|
||||
class ClockStepper(ABC):
|
||||
@abstractclassmethod
|
||||
def bump_clock(self):
|
||||
...
|
||||
|
||||
@abstractclassmethod
|
||||
def run(self):
|
||||
...
|
||||
|
||||
@abstractclassmethod
|
||||
def get_cycle_count(self):
|
||||
...
|
||||
|
||||
@abstractclassmethod
|
||||
def combinational_bump(self):
|
||||
...
|
||||
|
||||
class SingleClockStepper(ClockStepper):
|
||||
def __init__(self, mname: str, symbol: str, executor, table: SymbolTable):
|
||||
self.mname: str = mname
|
||||
self.clock_symbol: str = symbol
|
||||
self.executor = executor
|
||||
self.table: SymbolTable = table
|
||||
self.clock_cycles = 0
|
||||
self.combinational_bumps = 0
|
||||
|
||||
def handle_name(self, name):
|
||||
names = name.split(".")
|
||||
names.reverse()
|
||||
return names
|
||||
|
||||
def bump_clock(self, mname: str, clock_symbol: str, value: int):
|
||||
self.table.set_symbol_value(mname, self.handle_name(clock_symbol), value)
|
||||
self.clock_cycles += 1
|
||||
|
||||
def combinational_bump(self, value: int):
|
||||
self.combinational_bumps += value
|
||||
|
||||
def get_cycle_count(self):
|
||||
return self.clock_cycles
|
||||
|
||||
def run(self, steps: int):
|
||||
def raise_clock():
|
||||
self.table[self.mname][self.clock_symbol] ^= 1
|
||||
self.executor.execute(self.mname)
|
||||
|
||||
self.combinational_bumps = 0
|
||||
|
||||
def lower_clock():
|
||||
self.table[self.mname][self.clock_symbol] ^= 1
|
||||
self.combinational_bumps = 0
|
||||
|
||||
for _ in range(steps):
|
||||
if self.executor.get_inputchange():
|
||||
self.executor.execute(self.mname)
|
||||
self.clock_cycles += 1
|
||||
|
||||
raise_clock()
|
||||
lower_clock()
|
||||
|
||||
|
||||
class MultiClockStepper(ClockStepper):
|
||||
# TODO: Add MultiCLockStepper
|
||||
...
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
from functools import reduce
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.tester.wir import *
|
||||
from pyhcl.tester.symbol_table import SymbolTable
|
||||
from pyhcl.tester.exception import TesterException
|
||||
from pyhcl.tester.utils import DAG
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TesterCompiler:
|
||||
symbol_table: SymbolTable
|
||||
dags = {}
|
||||
modules = {}
|
||||
|
||||
def get_in_port_name(self, name: str, t: Type, d: Direction) -> List[str]:
|
||||
if isinstance(d, Input) and isinstance(t, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
|
||||
return [name]
|
||||
elif isinstance(d, Input) and isinstance(t, (VectorType, MemoryType)):
|
||||
names = []
|
||||
pnames = self.get_in_port_name(name, t.typ, d)
|
||||
for pn in pnames:
|
||||
for i in range(t.size):
|
||||
names.append(f"{pn}[{i}]")
|
||||
return names
|
||||
elif isinstance(t, BundleType):
|
||||
names = []
|
||||
for f in t.fields:
|
||||
pnames = []
|
||||
if isinstance(d, Input) and isinstance(f.flip, Default):
|
||||
pnames += self.get_in_port_name(f.name, f.typ, d)
|
||||
elif isinstance(d, Output) and isinstance(f.flip, Flip):
|
||||
pnames += self.get_in_port_name(f.name, f.typ, Input())
|
||||
for pn in pnames:
|
||||
names.append(f"{name}.{pn}")
|
||||
return names
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_out_port_name(self, name: str, t: Type, d: Direction) -> List[str]:
|
||||
if isinstance(d, Output) and isinstance(t, (UIntType, SIntType)):
|
||||
return [name]
|
||||
elif isinstance(d, Output) and isinstance(t, (VectorType, MemoryType)):
|
||||
names = []
|
||||
pnames = self.get_in_port_name(name, t.typ, d)
|
||||
for pn in pnames:
|
||||
for i in range(t.size):
|
||||
names.append(f"{pn}[{i}]")
|
||||
return names
|
||||
elif isinstance(t, BundleType):
|
||||
names = []
|
||||
for f in t.fields:
|
||||
pnames = []
|
||||
if isinstance(d, Output) and isinstance(f.flip, Default):
|
||||
pnames += self.get_out_port_name(f.name, f.typ, d)
|
||||
elif isinstance(d, Input) and isinstance(f.flip, Flip):
|
||||
pnames += self.get_out_port_name(f.name, f.typ, Output())
|
||||
for pn in pnames:
|
||||
names.append(f"{name}.{pn}")
|
||||
return names
|
||||
else:
|
||||
return []
|
||||
|
||||
def gen_dag_nodes(self, name: str, typ: Type):
|
||||
if isinstance(typ, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
|
||||
return [name]
|
||||
elif isinstance(typ, VectorType):
|
||||
names = []
|
||||
pre_names = self.gen_dag_nodes(name, typ.typ)
|
||||
for n in pre_names:
|
||||
for i in range(typ.size):
|
||||
names.append(f"{n}[{i}]")
|
||||
return names
|
||||
elif isinstance(typ, MemoryType):
|
||||
names = []
|
||||
pre_names = self.gen_dag_nodes(name, typ.typ)
|
||||
for n in pre_names:
|
||||
for i in range(typ.size):
|
||||
names.append(f"{n}[{i}]")
|
||||
return names
|
||||
elif isinstance(typ, BundleType):
|
||||
names = []
|
||||
for f in typ.fields:
|
||||
pre_names = self.gen_dag_nodes(f.name, f.typ)
|
||||
for n in pre_names:
|
||||
names.append(f"{name}.{n}")
|
||||
return names
|
||||
|
||||
def add_dag_node(self, mname: str, name: str, typ: Type):
|
||||
vs = self.gen_dag_nodes(name, typ)
|
||||
for v in vs:
|
||||
self.dags[mname].add_node_if_not_exists(v)
|
||||
|
||||
def add_dag_edge(self, mname: str, ind: Expression, dep: Expression):
|
||||
if isinstance(dep, (Reference, SubAccess, SubIndex, SubField)):
|
||||
self.dags[mname].add_edge(dep.serialize(), ind.serialize())
|
||||
elif isinstance(dep, Mux):
|
||||
self.add_dag_edge(mname, ind, dep.tval)
|
||||
self.add_dag_edge(mname, ind, dep.fval)
|
||||
elif isinstance(dep, ValidIf):
|
||||
self.add_dag_edge(mname, ind, dep.value)
|
||||
# elif isinstance(dep, DoPrim):
|
||||
# for arg in dep.args:
|
||||
# self.add_dag_edge(mname, ind, arg)
|
||||
|
||||
|
||||
def gen_working_ir(self, mname: str, names: list, expr: Expression) -> Expression:
|
||||
if isinstance(expr, SubField):
|
||||
names.append(expr.name)
|
||||
e = self.gen_working_ir(mname, names, expr.expr)
|
||||
get_func, set_func = e.get_func, e.set_func
|
||||
return WSubField(expr,
|
||||
lambda table=None: get_func(table),
|
||||
lambda s, table=None: set_func(s, table))
|
||||
elif isinstance(expr, SubAccess):
|
||||
index = self.gen_working_ir(mname, [], expr.index)
|
||||
index_get_func = index.get_func
|
||||
names.append(index_get_func())
|
||||
e = self.gen_working_ir(mname, names, expr.expr)
|
||||
get_func, set_func = e.get_func, e.set_func
|
||||
return WSubAccess(expr,
|
||||
lambda table=None: get_func(table),
|
||||
lambda s, table=None: set_func(s, table))
|
||||
elif isinstance(expr, SubIndex):
|
||||
# size = expr.expr.typ.size-1
|
||||
names.append(expr.value)
|
||||
e = self.gen_working_ir(mname, names, expr.expr)
|
||||
get_func, set_func = e.get_func, e.set_func
|
||||
return WSubField(expr,
|
||||
lambda table=None: get_func(table),
|
||||
lambda s, table=None: set_func(s, table))
|
||||
else:
|
||||
name = expr.serialize()
|
||||
names.append(name)
|
||||
if not self.symbol_table.has_symbol(mname, name):
|
||||
raise TesterException(f"Module [{mname}] Reference {name} is not declared")
|
||||
return WReference(expr,
|
||||
lambda table=None: self.symbol_table.get_symbol_value(mname, names, table),
|
||||
lambda s, table=None: self.symbol_table.set_symbol_value(mname, names, s, table))
|
||||
|
||||
def compile_op(self, op, args, consts):
|
||||
if isinstance(op, Add):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) + y.get_value(table), args + consts)
|
||||
elif isinstance(op, Sub):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) - y.get_value(table), args + consts)
|
||||
elif isinstance(op, Mul):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) * y.get_value(table), args + consts)
|
||||
elif isinstance(op, Div):
|
||||
return lambda table=None: int(reduce(lambda x, y: x.get_value(table) / y.get_value(table), args + consts))
|
||||
elif isinstance(op, Rem):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) % y.get_value(table), args + consts)
|
||||
elif isinstance(op, Lt):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) < y.get_value(table), args + consts)
|
||||
elif isinstance(op, Leq):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) <= y.get_value(table), args + consts)
|
||||
elif isinstance(op, Gt):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) > y.get_value(table), args + consts)
|
||||
elif isinstance(op, Geq):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) >= y.get_value(table), args + consts)
|
||||
elif isinstance(op, Eq):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) == y.get_value(table), args + consts)
|
||||
elif isinstance(op, Neq):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) != y.get_value(table), args + consts)
|
||||
elif isinstance(op, Neg):
|
||||
return lambda table=None: -(args + consts)[0].get_value(table)
|
||||
elif isinstance(op, Not):
|
||||
return lambda table=None: not (args + consts)[0].get_value(table)
|
||||
elif isinstance(op, And):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) & y.get_value(table), args + consts)
|
||||
elif isinstance(op, Or):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) | y.get_value(table), args + consts)
|
||||
elif isinstance(op, Xor):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) ^ y.get_value(table), args + consts)
|
||||
elif isinstance(op, Shl):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) << y.get_value(table), args + consts)
|
||||
elif isinstance(op, Shr):
|
||||
return lambda table=None: reduce(lambda x, y: x.get_value(table) >> y.get_value(table), args + consts)
|
||||
elif isinstance(op, Bits):
|
||||
def bits(args, consts, table=None):
|
||||
value = '{:032b}'.format(args[0].get_value(table))
|
||||
value = value[::-1]
|
||||
value_width = len(value)
|
||||
lsb = consts[0].get_value(table) if consts[0].get_value(table) < value_width else 0
|
||||
msb = consts[1].get_value(table) if consts[1].get_value(table) < value_width else 0
|
||||
value = value[msb: lsb] if lsb > msb else value[lsb]
|
||||
return int(value, 2)
|
||||
return lambda table=None: bits(args, consts, table)
|
||||
elif isinstance(op, Cat):
|
||||
def cat(args, table):
|
||||
hi = args[0].get_value(table) if isinstance(args[0].get_value(table), str) else bin(args[0].get_value(table))[2:]
|
||||
lo = args[1].get_value(table) if isinstance(args[1].get_value(table), str) else bin(args[1].get_value(table))[2:]
|
||||
# max_len = len(hi) if len(hi) >= len(lo) else len(lo)
|
||||
# hi = '{:032b}'.format(args[0].get_value(table))[-max_len:]
|
||||
# lo = '{:032b}'.format(args[1].get_value(table))[-max_len:]
|
||||
return hi+lo
|
||||
return lambda table=None: cat(args, table)
|
||||
elif isinstance(op, (AsUInt, AsSInt)):
|
||||
return lambda table=None: int(args[0].get_value(table))
|
||||
elif isinstance(op, AsClock):
|
||||
return lambda table=None: int(args[0].get_value(table)) % 2
|
||||
else:
|
||||
return lambda table=None: None
|
||||
|
||||
def compile_e(self, mname: str, expr: Expression) -> Expression:
|
||||
if isinstance(expr, (Reference, SubField, SubAccess, SubIndex)):
|
||||
names = []
|
||||
return self.gen_working_ir(mname, names, expr)
|
||||
elif isinstance(expr, DoPrim):
|
||||
args = list(map(lambda arg: self.compile_e(mname, arg), expr.args))
|
||||
consts = [WInt(const) for const in expr.consts]
|
||||
return WDoPrim(expr, self.compile_op(expr.op, args, consts))
|
||||
elif isinstance(expr, Mux):
|
||||
cond = self.compile_e(mname, expr.cond)
|
||||
tval = self.compile_e(mname, expr.tval)
|
||||
fval = self.compile_e(mname, expr.fval)
|
||||
return WMux(expr, lambda table=None: tval.get_value(table) if cond.get_value(table) else fval.get_value(table))
|
||||
elif isinstance(expr, ValidIf):
|
||||
cond = self.compile_e(mname, expr.cond)
|
||||
value = self.compile_e(mname, expr.value)
|
||||
return WValidIf(expr, lambda table=None: value.get_value(table) if cond.get_value(table) else None)
|
||||
elif isinstance(expr, UIntLiteral):
|
||||
return WUIntLiteral(expr)
|
||||
elif isinstance(expr, SIntLiteral):
|
||||
return WSIntLiteral(expr)
|
||||
else:
|
||||
return expr
|
||||
|
||||
def compile_s(self, mname: str, s: Statement) -> Statement:
|
||||
if isinstance(s, EmptyStmt):
|
||||
return EmptyStmt()
|
||||
elif isinstance(s, Conditionally):
|
||||
return Conditionally(self.compile_e(mname, s.pred), self.compile_s(mname, s.conseq), self.compile_s(mname, s.alt), s.info)
|
||||
elif isinstance(s, Block):
|
||||
insts = [sx for sx in s.stmts if isinstance(sx, DefInstance)]
|
||||
nodes = [sx for sx in s.stmts if isinstance(sx, DefNode)]
|
||||
wires = [sx for sx in s.stmts if isinstance(sx, DefWire)]
|
||||
regs = [sx for sx in s.stmts if isinstance(sx, DefRegister)]
|
||||
mems = [sx for sx in s.stmts if isinstance(sx, WDefMemory)]
|
||||
cons = [sx for sx in s.stmts if isinstance(sx, Connect)]
|
||||
stmts = regs + mems + insts + wires + nodes + cons
|
||||
return Block([self.compile_s(mname, sx) for sx in stmts])
|
||||
elif isinstance(s, DefRegister):
|
||||
self.add_dag_node(mname, s.name, s.typ)
|
||||
self.symbol_table.set_symbol(mname, s)
|
||||
return DefRegister(s.name,
|
||||
s.typ,
|
||||
self.compile_e(mname, s.clock),
|
||||
self.compile_e(mname, s.reset),
|
||||
self.compile_e(mname, s.init),
|
||||
s.info)
|
||||
elif isinstance(s, WDefMemory):
|
||||
self.add_dag_node(mname, s.name, s.memType)
|
||||
self.symbol_table.set_symbol(mname, s)
|
||||
return s
|
||||
elif isinstance(s, DefInstance):
|
||||
for p in s.ports:
|
||||
self.add_dag_node(mname, f"{s.name}_{p.name}", p.typ)
|
||||
self.symbol_table.set_symbol(mname, s)
|
||||
return s
|
||||
elif isinstance(s, DefWire):
|
||||
self.add_dag_node(mname, s.name, s.typ)
|
||||
self.symbol_table.set_symbol(mname, s)
|
||||
return s
|
||||
elif isinstance(s, DefNode):
|
||||
self.add_dag_node(mname, s.name, s.value.typ)
|
||||
self.add_dag_edge(mname, Reference(s.name, s.value.typ), s.value)
|
||||
self.symbol_table.set_symbol(mname, s)
|
||||
return DefNode(s.name, self.compile_e(mname, s.value), s.info)
|
||||
elif isinstance(s, Connect):
|
||||
self.add_dag_edge(mname, s.loc, s.expr)
|
||||
return Connect(self.compile_e(mname, s.loc), self.compile_e(mname, s.expr), s.info)
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def compile_p(self, mname: str, p: Port):
|
||||
self.add_dag_node(mname, p.name, p.typ)
|
||||
self.symbol_table.set_symbol(mname, p)
|
||||
return p
|
||||
|
||||
def compile_m(self, m: DefModule):
|
||||
if isinstance(m, Module):
|
||||
self.dags[m.name] = DAG()
|
||||
self.symbol_table.set_module(m.name)
|
||||
ports = list(map(lambda p: self.compile_p(m.name, p), m.ports))
|
||||
body = self.compile_s(m.name, m.body)
|
||||
return Module(m.name, ports, body, m.typ, m.info)
|
||||
elif isinstance(m, ExtModule):
|
||||
self.symbol_table.set_module(m.name)
|
||||
ports = list(map(lambda p: self.compile_p(m.name, p), m.ports))
|
||||
return ExtModule(m.name, ports, m.defname, m.typ, m.info)
|
||||
|
||||
def compile(self, c: Circuit):
|
||||
for m in c.modules:
|
||||
self.modules[m.name] = m
|
||||
modules = list(map(lambda m: self.compile_m(m), c.modules))
|
||||
return Circuit(modules, c.main, c.info), self.dags
|
|
@ -0,0 +1,6 @@
|
|||
class TesterException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
|
@ -0,0 +1,244 @@
|
|||
from typing import Dict, List
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.ir.low_prim import *
|
||||
from pyhcl.tester.compiler import TesterCompiler
|
||||
from pyhcl.tester.symbol_table import SymbolTable
|
||||
from pyhcl.tester.clock_stepper import SingleClockStepper
|
||||
|
||||
from pyhcl.passes.check_form import CheckHighForm
|
||||
from pyhcl.passes.check_types import CheckTypes
|
||||
from pyhcl.passes.check_flows import CheckFlow
|
||||
from pyhcl.passes.check_widths import CheckWidths
|
||||
from pyhcl.passes.auto_inferring import AutoInferring
|
||||
from pyhcl.passes.replace_subaccess import ReplaceSubaccess
|
||||
from pyhcl.passes.expand_aggregate import ExpandAggregate
|
||||
from pyhcl.passes.expand_whens import ExpandWhens
|
||||
from pyhcl.passes.expand_memory import ExpandMemory
|
||||
from pyhcl.passes.handle_instance import HandleInstance
|
||||
from pyhcl.passes.optimize import Optimize
|
||||
from pyhcl.passes.remove_access import RemoveAccess
|
||||
from pyhcl.passes.utils import AutoName
|
||||
|
||||
class TesterExecuter:
|
||||
def __init__(self, circuit: Circuit):
|
||||
self.circuit = circuit
|
||||
self.symbol_table = SymbolTable()
|
||||
self.reg_table = {}
|
||||
self.mem_table = {}
|
||||
self.clock_table = {}
|
||||
self.inputchange = False
|
||||
|
||||
|
||||
def handle_name(self, name):
|
||||
names = name.split(".")
|
||||
names.reverse()
|
||||
return names
|
||||
|
||||
def get_inputchange(self):
|
||||
return self.inputchange
|
||||
|
||||
def get_ref_name(self, e: Expression):
|
||||
if isinstance(e, SubField):
|
||||
return self.get_ref_name(e.expr)
|
||||
elif isinstance(e, SubIndex):
|
||||
return self.get_ref_name(e.expr)
|
||||
elif isinstance(e, SubAccess):
|
||||
return self.get_ref_name(e.expr)
|
||||
else:
|
||||
return e.name
|
||||
|
||||
def execute_stmt(self, m: Module, stmt: Statement, table=None):
|
||||
if isinstance(stmt, Connect):
|
||||
if stmt.loc.expr.serialize() in self.reg_table:
|
||||
if self.reg_table[stmt.loc.expr.serialize()].reset.get_value(table) == 0:
|
||||
if stmt.expr.get_value(table) is not None:
|
||||
stmt.loc.set_value(stmt.expr.get_value(table), table)
|
||||
else:
|
||||
self.symbol_table.set_symbol_value(m.name, self.handle_name(stmt.loc.expr.name),
|
||||
self.reg_table[stmt.loc.expr.serialize()].init.get_value(table), table)
|
||||
elif stmt.loc.expr.serialize() in self.mem_table:
|
||||
mem_data = stmt.loc.expr.serialize()
|
||||
mem = mem_data.split("_")[0]
|
||||
mem_addr = mem_data.replace("data", "addr")
|
||||
mem_en = mem_data.replace("data", "en")
|
||||
mem_mask = mem_data.replace("data", "mask")
|
||||
if self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_en), table) > 0 and \
|
||||
self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_mask), table) > 0:
|
||||
if self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table) is not None:
|
||||
self.symbol_table[m.name][mem][self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)]\
|
||||
= self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)
|
||||
elif stmt.expr.expr.serialize() in self.mem_table:
|
||||
mem_data = stmt.expr.expr.serialize()
|
||||
mem = mem_data.split("_")[0]
|
||||
mem_addr = mem_data.replace("data", "addr")
|
||||
mem_en = mem_data.replace("data", "en")
|
||||
if self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_en), table) > 0:
|
||||
if self.symbol_table[m.name][mem][self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)] is not None:
|
||||
self.symbol_table.set_symbol_value(m.name, self.handle_name(mem_data),
|
||||
self.symbol_table[m.name][mem][self.symbol_table.get_symbol_value(m.name, self.handle_name(mem_addr), table)], table)
|
||||
else:
|
||||
if stmt.expr.get_value(table) is not None:
|
||||
stmt.loc.set_value(stmt.expr.get_value(table), table)
|
||||
elif isinstance(stmt, DefNode):
|
||||
self.symbol_table.set_symbol_value(m.name, self.handle_name(stmt.name), stmt.value.get_value(table), table)
|
||||
elif isinstance(stmt, WDefMemory):
|
||||
for rw in stmt.writers:
|
||||
mem_table.append(f"{stmt.name}_{rw}_data")
|
||||
for re in stmt.readers:
|
||||
mem_table.append(f"{stmt.name}_{re}_data")
|
||||
elif isinstance(stmt, Block):
|
||||
for s in stmt.stmts:
|
||||
self.execute_stmt(s, table)
|
||||
|
||||
def execute_module(self, m: Module, ms: Dict[str, DefModule], table=None):
|
||||
execute_stmts = OrderedDict()
|
||||
instances = OrderedDict()
|
||||
|
||||
def get_in_port_name(name: str, t: Type, d: Direction) -> List[str]:
|
||||
if isinstance(d, Input) and isinstance(t, (UIntType, SIntType, ClockType, ResetType, AsyncResetType)):
|
||||
return [name]
|
||||
elif isinstance(d, Input) and isinstance(t, (VectorType, MemoryType)):
|
||||
names = []
|
||||
pnames = get_in_port_name(name, t.typ, d)
|
||||
for pn in pnames:
|
||||
for i in range(t.size):
|
||||
names.append(f"{pn}[{i}]")
|
||||
return names
|
||||
elif isinstance(t, BundleType):
|
||||
names = []
|
||||
for f in t.fields:
|
||||
pnames = []
|
||||
if isinstance(d, Input) and isinstance(f.flip, Default):
|
||||
pnames += get_in_port_name(f.name, f.typ, d)
|
||||
elif isinstance(d, Output) and isinstance(f.flip, Flip):
|
||||
pnames += get_in_port_name(f.name, f.typ, Input())
|
||||
for pn in pnames:
|
||||
names.append(f"{name}.{pn}")
|
||||
return names
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_out_port_name(name: str, t: Type, d: Direction) -> List[str]:
|
||||
if isinstance(d, Output) and isinstance(t, (UIntType, SIntType)):
|
||||
return [name]
|
||||
elif isinstance(d, Output) and isinstance(t, (VectorType, MemoryType)):
|
||||
names = []
|
||||
pnames = get_in_port_name(name, t.typ, d)
|
||||
for pn in pnames:
|
||||
for i in range(t.size):
|
||||
names.append(f"{pn}[{i}]")
|
||||
return names
|
||||
elif isinstance(t, BundleType):
|
||||
names = []
|
||||
for f in t.fields:
|
||||
pnames = []
|
||||
if isinstance(d, Output) and isinstance(f.flip, Default):
|
||||
pnames += get_out_port_name(f.name, f.typ, d)
|
||||
elif isinstance(d, Input) and isinstance(f.flip, Flip):
|
||||
pnames += get_out_port_name(f.name, f.typ, Output())
|
||||
for pn in pnames:
|
||||
names.append(f"{name}.{pn}")
|
||||
return names
|
||||
else:
|
||||
return []
|
||||
|
||||
def _deal_stmt(s: Statement):
|
||||
if isinstance(s, Block):
|
||||
for stmt in s.stmts:
|
||||
_deal_stmt(stmt)
|
||||
elif isinstance(s, Connect):
|
||||
execute_stmts[s.loc.expr.serialize()] = s
|
||||
elif isinstance(s, DefNode):
|
||||
execute_stmts[s.name] = s
|
||||
elif isinstance(s, DefRegister):
|
||||
self.reg_table[s.name] = s
|
||||
elif isinstance(s, WDefMemory):
|
||||
self.mem_table[s.name] = s
|
||||
elif isinstance(s, DefInstance):
|
||||
instances[s.name] = s
|
||||
|
||||
_deal_stmt(m.body)
|
||||
|
||||
for sx in m.body.stmts:
|
||||
self.execute_stmt(m, sx, table)
|
||||
|
||||
for ins in instances:
|
||||
ref_module_name = instances[ins].module
|
||||
ref_module = ms[ref_module_name]
|
||||
ref_table = deepcopy(self.symbol_table.table[ref_module_name])
|
||||
|
||||
module_inputs = []
|
||||
for p in ref_module.ports:
|
||||
module_inputs += get_in_port_name(p.name, p.typ, p.direction)
|
||||
ref_inputs = [f"{ins}_{mi}" for mi in module_inputs]
|
||||
|
||||
for i in range(len(module_inputs)):
|
||||
self.symbol_table.set_symbol_value(ref_module_name,
|
||||
self.handle_name(module_inputs[i]),
|
||||
self.symbol_table.get_symbol_value(m.name, self.handle_name(ref_inputs[i])),
|
||||
ref_table)
|
||||
|
||||
self.execute_module(ref_module, ms, ref_table)
|
||||
|
||||
module_outputs = []
|
||||
for p in ref_module.ports:
|
||||
module_outputs += get_out_port_name(p.name, p.typ, p.direction)
|
||||
ref_outputs = [f"{ins}_{mi}" for mi in module_outputs]
|
||||
|
||||
for i in range(len(module_outputs)):
|
||||
self.symbol_table.set_symbol_value(m.name,
|
||||
self.handle_name(ref_outputs[i]),
|
||||
self.symbol_table.get_symbol_value(ref_module_name, self.handle_name(module_outputs[i]), ref_table))
|
||||
|
||||
for v in self.dags[m.name].travel_graph(ref_outputs):
|
||||
if v in execute_stmts:
|
||||
self.execute_stmt(m, execute_stmts[v], table)
|
||||
|
||||
def init_clock(self, table = None):
|
||||
if table is None:
|
||||
table = self.symbol_table.clock_table
|
||||
for mname in table:
|
||||
if mname not in self.clock_table:
|
||||
self.clock_table[mname] = {}
|
||||
for symbol in table[mname]:
|
||||
self.clock_table[mname][symbol] = SingleClockStepper(mname, symbol, self, table)
|
||||
|
||||
def init_executer(self):
|
||||
AutoName()
|
||||
self.circuit = CheckHighForm(self.circuit).run()
|
||||
self.circuit = AutoInferring().run(self.circuit)
|
||||
self.circuit = CheckTypes().run(self.circuit)
|
||||
self.circuit = CheckFlow().run(self.circuit)
|
||||
self.circuit = CheckWidths().run(self.circuit)
|
||||
self.circuit = ExpandMemory().run(self.circuit)
|
||||
self.circuit = ReplaceSubaccess().run(self.circuit)
|
||||
self.circuit = ExpandAggregate().run(self.circuit)
|
||||
self.circuit = RemoveAccess().run(self.circuit)
|
||||
self.circuit = ExpandWhens().run(self.circuit)
|
||||
self.circuit = HandleInstance().run(self.circuit)
|
||||
self.circuit = Optimize().run(self.circuit)
|
||||
self.compiler = TesterCompiler(self.symbol_table)
|
||||
self.compiled_circuit, self.dags = self.compiler.compile(self.circuit)
|
||||
self.init_clock()
|
||||
|
||||
def set_value(self, mname: str, name: str, singal: int):
|
||||
self.inputchange = True
|
||||
self.symbol_table.set_symbol_value(mname, self.handle_name(name), singal)
|
||||
|
||||
def get_value(self, mname: str, name: str):
|
||||
if self.inputchange:
|
||||
self.execute(mname)
|
||||
self.inputchange = False
|
||||
return self.symbol_table.get_symbol_value(mname, self.handle_name(name))
|
||||
|
||||
def step(self, n: int, mname: str):
|
||||
if n > 0:
|
||||
for name in self.clock_table[mname]:
|
||||
self.clock_table[mname][name].run(n)
|
||||
|
||||
def execute(self, mname: str):
|
||||
ms = {m.name: m for m in self.compiled_circuit.modules}
|
||||
m = ms[mname]
|
||||
self.execute_module(m, ms)
|
|
@ -0,0 +1,72 @@
|
|||
from pyhcl.ir.low_ir import *
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SymbolTable:
|
||||
table = {}
|
||||
clock_table = {}
|
||||
|
||||
def gen_typ(self, typ: Type):
|
||||
if isinstance(typ, (AsyncResetType, ResetType, ClockType, UIntType, SIntType)):
|
||||
return 0
|
||||
elif isinstance(typ, VectorType):
|
||||
return [self.gen_typ(typ.typ) for _ in range(typ.size)]
|
||||
elif isinstance(typ, BundleType):
|
||||
return {f.name: self.gen_typ(f.typ) for f in typ.fields}
|
||||
elif isinstance(typ, MemoryType):
|
||||
return [self.gen_typ(typ.typ) for _ in range(typ.size)]
|
||||
|
||||
def has_symbol(self, mname: str, name: str) -> bool:
|
||||
return name in self.table[mname]
|
||||
|
||||
def set_module(self, mname: str):
|
||||
self.table[mname] = {}
|
||||
self.clock_table[mname] = {}
|
||||
|
||||
def set_symbol(self, mname: str, symbol):
|
||||
if isinstance(symbol, Port):
|
||||
if isinstance(symbol.typ, ClockType):
|
||||
self.clock_table[mname][symbol.name] = self.gen_typ(symbol.typ)
|
||||
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
|
||||
if isinstance(symbol, DefWire):
|
||||
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
|
||||
elif isinstance(symbol, DefRegister):
|
||||
self.table[mname][symbol.name] = self.gen_typ(symbol.typ)
|
||||
elif isinstance(symbol, WDefMemory):
|
||||
self.table[mname][symbol.name] = self.gen_typ(symbol.memType)
|
||||
for rw in symbol.writers:
|
||||
self.table[mname][f"{symbol.name}_{rw}_data"] = self.gen_typ(symbol.dataType)
|
||||
self.table[mname][f"{symbol.name}_{rw}_addr"] = self.gen_typ(UIntType(IntWidth(get_binary_width(symbol.depth))))
|
||||
self.table[mname][f"{symbol.name}_{rw}_clk"] = self.gen_typ(ClockType())
|
||||
self.table[mname][f"{symbol.name}_{rw}_en"] = self.gen_typ(UIntType(IntWidth(1)))
|
||||
self.table[mname][f"{symbol.name}_{rw}_mask"] = self.gen_typ(UIntType(IntWidth(1)))
|
||||
|
||||
for rr in symbol.readers:
|
||||
self.table[mname][f"{symbol.name}_{rr}_data"] = self.gen_typ(symbol.dataType)
|
||||
self.table[mname][f"{symbol.name}_{rr}_addr"] = self.gen_typ(UIntType(IntWidth(get_binary_width(symbol.depth))))
|
||||
self.table[mname][f"{symbol.name}_{rr}_clk"] = self.gen_typ(ClockType())
|
||||
self.table[mname][f"{symbol.name}_{rr}_en"] = self.gen_typ(UIntType(IntWidth(1)))
|
||||
elif isinstance(symbol, DefNode):
|
||||
self.table[mname][symbol.name] = self.gen_typ(symbol.value.typ)
|
||||
elif isinstance(symbol, DefInstance):
|
||||
for p in symbol.ports:
|
||||
name = f"{symbol.name}_{p.name}"
|
||||
self.table[mname][name] = self.gen_typ(p.typ)
|
||||
else:
|
||||
...
|
||||
|
||||
def get_symbol_value(self, mname: str, names: list, table = None):
|
||||
if table is None:
|
||||
table = self.table[mname]
|
||||
ns = names[:]
|
||||
while len(ns) > 1:
|
||||
table = table[ns.pop()]
|
||||
return table[ns.pop()]
|
||||
|
||||
def set_symbol_value(self, mname: str, names: list, signal: int, table = None):
|
||||
if table is None:
|
||||
table = self.table[mname]
|
||||
ns = names[:]
|
||||
while len(ns) > 1:
|
||||
table = table[ns.pop()]
|
||||
table[ns.pop()] = signal
|
||||
return signal
|
|
@ -0,0 +1,24 @@
|
|||
from pyhcl.tester.executer import TesterExecuter
|
||||
from pyhcl.ir.low_ir import *
|
||||
from pyhcl.dsl.emitter import Emitter
|
||||
|
||||
class Tester:
|
||||
def __init__(self, m: Module):
|
||||
circuit = Emitter.elaborate(m)
|
||||
self.main = circuit.main
|
||||
self.executer = TesterExecuter(circuit)
|
||||
self.executer.init_executer()
|
||||
|
||||
|
||||
def poke(self, name: str, value: int):
|
||||
self.executer.set_value(self.main, name, value)
|
||||
|
||||
def peek(self, name: str) -> int:
|
||||
res = self.executer.get_value(self.main, name)
|
||||
return int(res, 2) if isinstance(res, str) else res
|
||||
|
||||
def expect(self, a, b) -> bool:
|
||||
return a == b
|
||||
|
||||
def step(self, n):
|
||||
self.executer.step(n, self.main)
|
|
@ -0,0 +1,131 @@
|
|||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from pyhcl.tester.exception import TesterException
|
||||
class DAG:
|
||||
""" Directed acyclic graph implementation."""
|
||||
|
||||
def __init__(self):
|
||||
self.graph = OrderedDict()
|
||||
|
||||
def add_node(self, name: str, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if name in graph:
|
||||
...
|
||||
else:
|
||||
graph[name] = set()
|
||||
|
||||
def add_node_if_not_exists(self, name: str, graph = None):
|
||||
try:
|
||||
self.add_node(name, graph = graph)
|
||||
except TesterException as e:
|
||||
raise e
|
||||
|
||||
def delete_node(self, name: str, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if name not in graph:
|
||||
raise TesterException(f'node {name} is not exists.')
|
||||
graph.pop(name)
|
||||
for _, edges in graph.items():
|
||||
if name in edges:
|
||||
edges.remove(name)
|
||||
|
||||
def delete_node_if_exists(self, name: str, graph = None):
|
||||
try:
|
||||
self.delete_node(name, graph = graph)
|
||||
except TesterException as e:
|
||||
raise e
|
||||
|
||||
def add_edge(self, ind_node, dep_node, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if ind_node not in graph or dep_node not in graph:
|
||||
raise TesterException(f'nodes do not exist in graph.')
|
||||
test_graph = deepcopy(graph)
|
||||
test_graph[ind_node].add(dep_node)
|
||||
is_valid, msg = self.validate(test_graph)
|
||||
if is_valid:
|
||||
graph[ind_node].add(dep_node)
|
||||
else:
|
||||
raise TesterException(f'Loop do exist in graph: {msg}')
|
||||
|
||||
def delete_edge(self, ind_node, dep_node, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if dep_node not in graph.get(ind_node, []):
|
||||
raise TesterException(f'This edge does not exist in graph')
|
||||
graph[ind_node].remove(dep_node)
|
||||
|
||||
def ind_nodes(self, graph = None):
|
||||
if graph == None:
|
||||
graph = self.graph
|
||||
dep_nodes = set(
|
||||
node for deps in graph.values() for node in deps
|
||||
)
|
||||
|
||||
return [node for node in graph.keys() if node not in dep_nodes]
|
||||
|
||||
def topological_sort(self, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
result = []
|
||||
in_degree = defaultdict(lambda: 0)
|
||||
|
||||
for u in graph:
|
||||
for v in graph[u]:
|
||||
in_degree[v] += 1
|
||||
|
||||
ready = [node for node in graph if not in_degree[node]]
|
||||
|
||||
while ready:
|
||||
u = ready.pop()
|
||||
result.append(u)
|
||||
for v in graph[u]:
|
||||
in_degree[v] -= 1
|
||||
if in_degree[v] == 0:
|
||||
ready.append(v)
|
||||
if len(result) == len(graph):
|
||||
return result
|
||||
else:
|
||||
raise TesterException(f'graph is not acyclic.')
|
||||
|
||||
|
||||
def validate(self, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
if len(self.ind_nodes(graph)) == 0:
|
||||
return False, 'no independent nodes detected.'
|
||||
try:
|
||||
self.topological_sort(graph)
|
||||
except TesterException:
|
||||
return False, 'graph is not acyclic.'
|
||||
return True, 'valid'
|
||||
|
||||
def visit_graph(self, graph = None):
|
||||
visited = []
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
for v in graph:
|
||||
for u in graph[v]:
|
||||
visited.append(f'{v} -> {u}')
|
||||
return visited
|
||||
|
||||
def travel_graph(self, init: list, graph = None):
|
||||
if graph is None:
|
||||
graph = self.graph
|
||||
visits = []
|
||||
history = set()
|
||||
while len(init) > 0:
|
||||
visited = init.pop(0)
|
||||
if len(graph[visited]) > 0:
|
||||
for u in graph[visited]:
|
||||
if u not in init and u not in history:
|
||||
init.append(u)
|
||||
history.add(visited)
|
||||
visits.append(visited)
|
||||
return visits
|
||||
|
||||
|
||||
def size(self):
|
||||
return len(self.graph)
|
|
@ -0,0 +1,154 @@
|
|||
from pyclbr import Function
|
||||
from pyhcl.ir.low_ir import *
|
||||
@dataclass(frozen=True)
|
||||
class WUIntLiteral(Expression):
|
||||
expr: Expression
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.expr.value
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WSIntLiteral(Expression):
|
||||
expr: Expression
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.expr.value
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WReference(Expression):
|
||||
expr: Expression
|
||||
get_func: Function
|
||||
set_func: Function
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.get_func(*args)
|
||||
|
||||
def set_value(self, *args) -> int:
|
||||
return self.set_func(*args)
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WSubField(Expression):
|
||||
expr: Expression
|
||||
get_func: Function
|
||||
set_func: Function
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.get_func(*args)
|
||||
|
||||
def set_value(self, *args) -> int:
|
||||
return self.set_func(*args)
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WSubIndex(Expression):
|
||||
expr: Expression
|
||||
get_func: Function
|
||||
set_func: Function
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.get_func(*args)
|
||||
|
||||
def set_value(self, *args) -> int:
|
||||
return self.set_func(*args)
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WSubAccess(Expression):
|
||||
expr: Expression
|
||||
get_func: Function
|
||||
set_func: Function
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.get_func(*args)
|
||||
|
||||
def set_value(self, *args) -> int:
|
||||
return self.set_func(*args)
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WMux(Expression):
|
||||
expr: Expression
|
||||
get_func: Function
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.get_func(*args)
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WValidIf(Expression):
|
||||
expr: Expression
|
||||
get_func: Function
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.get_func(*args)
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WDoPrim(Expression):
|
||||
expr: Expression
|
||||
get_func: Function
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.get_func(*args)
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WInt(Expression):
|
||||
value: int
|
||||
|
||||
def get_value(self, *args) -> int:
|
||||
return self.value
|
||||
|
||||
def serialize(self) -> str:
|
||||
...
|
||||
|
||||
def verilog_serialize(self) -> str:
|
||||
...
|
Loading…
Reference in New Issue