826 lines
31 KiB
Python
826 lines
31 KiB
Python
import pathlib
|
|
import shutil
|
|
import collections
|
|
import logging
|
|
import argparse
|
|
from string import Template
|
|
import re
|
|
from .. import util
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Lower this logger's threshold so DEBUG and above are passed on to handlers.
logger.setLevel(logging.DEBUG)
|
|
|
|
def parse_verilog_input_info(inputs):
    """Group the Verilog input ports of a module into function arguments.

    A port named ``<arg>_datain`` declares a pointer argument ``<arg>``.
    A port that is neither an ``ap_*`` signal nor one of the pointer bus
    signals (``*_req_full_n``, ``*_rsp_empty_n``, ``*_datain``) declares a
    scalar argument named after the port itself.

    inputs: dict = { port name: {'width': port bitwidth}}

    Returns: OrderedDict = { arg name: {'arg_idx': position of the argument,
                                        'type': 'pointer' or 'scalar',
                                        'width': bitwidth of the data port,
                                        'num_signal': number of input ports
                                                      attributed to the arg}}

    Raises: RuntimeError if more than 2 arguments are found, if a bus signal
            refers to an unknown argument, or if an argument does not own the
            expected number of input ports (3 for a pointer, 1 for a scalar).
    """
    re_pointer_data = re.compile(r"(\S+)_datain")
    re_pointer_signals = re.compile(
        r"(\S+)_req_full_n|(\S+)_rsp_empty_n|(\S+)_datain")
    re_ap_signals = re.compile(r"ap_\S+")

    def fail(message):
        # Log and raise with the same text.  The previous bare ``raise`` had
        # no active exception and only produced an unrelated RuntimeError.
        logger.critical(message)
        raise RuntimeError(message)

    # Pass 1: create one entry per argument, numbered in declaration order.
    input_info = collections.OrderedDict()
    arg_count = 0
    for port, port_info in inputs.items():
        data_match = re_pointer_data.match(port)
        if data_match:
            arg_name, arg_type = data_match.group(1), 'pointer'
        elif not re_pointer_signals.match(port) and not re_ap_signals.match(port):
            arg_name, arg_type = port, 'scalar'
        else:
            # ap_* control signal or a non-data pointer bus signal
            continue
        input_info[arg_name] = {
            'arg_idx': arg_count,
            'type': arg_type,
            'width': port_info['width'],
        }
        arg_count += 1

    # Check for correctness
    if len(input_info) > 2:
        fail("Only accept function with no more than 2 arguments!")

    # Pass 2: count how many input ports belong to each argument.
    for port in inputs:
        bus_match = re_pointer_signals.match(port)
        if bus_match:
            # Exactly one alternative of the pattern captured the arg name.
            owner = next(g for g in bus_match.groups() if g is not None)
        elif not re_ap_signals.match(port):
            owner = port
        else:
            continue
        if owner not in input_info:
            # arg should have been created in pass 1
            fail("Unexpected Signal {}!".format(port))
        entry = input_info[owner]
        entry['num_signal'] = entry.get('num_signal', 0) + 1

    # Pass 3: every argument must own the expected number of input ports.
    expected_signals = {'pointer': 3, 'scalar': 1}
    for entry in input_info.values():
        if entry['num_signal'] != expected_signals[entry['type']]:
            fail("The AP bus interfance did not generate expected number of inputs!")
    return input_info
|
|
|
|
|
|
def generate_rocc_input_info(input_info):
    """Summarise the arguments as comma-separated lists for the templates.

    input_info: dict = { arg name: {'arg_idx', 'type', 'width', ...}}

    Returns: dict of comma-joined strings with the widths and argument
             indices of the scalar arguments and the address widths (always
             64), data widths and argument indices of the pointer arguments.
    """
    scalars = [arg for arg in input_info.values() if arg['type'] == 'scalar']
    pointers = [arg for arg in input_info.values() if arg['type'] == 'pointer']

    def comma_list(values):
        return ','.join(str(value) for value in values)

    return {
        'SCALAR_DATA_WIDTH_ARR': comma_list(arg['width'] for arg in scalars),
        'SCALAR_IDX_ARR': comma_list(arg['arg_idx'] for arg in scalars),
        'PTR_ADDR_WIDTH_ARR': comma_list(64 for _ in pointers),
        'PTR_DATA_WIDTH_ARR': comma_list(arg['width'] for arg in pointers),
        'PTR_IDX_ARR': comma_list(arg['arg_idx'] for arg in pointers),
    }
|
|
|
|
|
|
def get_rocc_scalarIO_count(input_info):
    """Return how many entries of input_info have type 'scalar'."""
    return sum(1 for arg in input_info.values() if arg['type'] == 'scalar')
|
|
|
|
|
|
def generate_rocc_scalarIO(num_scalar):
    """Return the Scala declaration of the scalar IO bundle.

    num_scalar: number of scalar arguments

    Returns: the ``scalar_io`` declaration line when there is at least one
             scalar argument, otherwise an empty string.
    """
    if not num_scalar > 0:
        return ""
    return " val scalar_io = HeterogeneousBag(scalar_io_dataWidths.map(w => Input(UInt(w.W))))\n"
|
|
|
|
|
|
def generate_rocc_scalarIO_stmt0(input_info):
    """Return the Scala statements that drive the scalar IOs from rs1/rs2.

    Each scalar argument gets its own slot of ``accel.io.scalar_io``,
    numbered in declaration order among the scalars only.  The source
    register is selected by the argument's overall position (``arg_idx`` 0
    reads ``rs1``, 1 reads ``rs2``).

    input_info: dict = { arg name: {'arg_idx', 'type', ...}}

    Returns: one assignment line per scalar argument, or "" if there are none.
    """
    scalar_args = [arg for arg in input_info.values() if arg['type'] == 'scalar']
    # The slot index has to advance per scalar; it used to stay at 0, which
    # wired every scalar argument to scalar_io(0).
    return "".join(
        "accel.io.scalar_io({}) := rs{}\n".format(slot, arg['arg_idx'] + 1)
        for slot, arg in enumerate(scalar_args)
    )
|
|
|
|
|
|
def generate_rocc_scalarIO_stmt1(num_scalar):
    """Return the Scala loop that feeds the scalar IOs from ``cArgs``.

    num_scalar: number of scalar arguments

    Returns: the loop text when there is at least one scalar argument,
             otherwise an empty string.
    """
    if not num_scalar > 0:
        return ""
    return "\n".join((
        "//Scalar values",
        "for(i <- 0 until accel.io.scalar_io.length){",
        "accel.io.scalar_io(i) := cArgs(accel.scalar_io_argLoc(i))",
        "}",
    ))
|
|
|
|
|
|
def generate_rocc_ap_return_stmt(outputs):
    """Return the Scala definition of ``ap_return`` for the RoCC wrapper.

    outputs: dict keyed by output port name

    Returns: a line reading the accelerator's return port when the module
             has an 'ap_return' output, otherwise a 4-bit UInt placeholder.
    """
    has_return = 'ap_return' in outputs
    source = "accel.io.ap.rtn" if has_return else "UInt(4.W)"
    return "val ap_return = " + source + "\n"
|
|
|
|
|
|
def generate_vals(io, width):
    """Return the Chisel port expression for one signal.

    io: direction wrapper placed around the type (callers in this file pass
        'Input' or 'Output')
    width: bitwidth; exactly the integer 1 yields a Bool, anything else is
           formatted into ``UInt(<width>.W)``
    """
    chisel_type = "Bool()" if width == 1 else "UInt({}.W)".format(width)
    return "{}({})".format(io, chisel_type)
|
|
|
|
|
|
def generate_params(params):
    """Return one Scala ``val <name> = <value>`` line per parameter.

    params: dict = { parameter name: value }
    """
    return ["val {} = {}".format(name, value) for name, value in params.items()]
|
|
|
|
|
|
def generate_args(inputs, outputs):
    """Return the Scala port declarations of the BlackBox IO bundle.

    inputs: dict = { input port name: {'width': bitwidth}}
    outputs: dict = { output port name: {'width': bitwidth}}

    Returns: list of ``val <port> = <type>`` strings, inputs first and then
             outputs.  Inputs whose name starts with ``ap_clk`` are declared
             as ``Input(Clock())`` regardless of their width.
    """
    clock_pattern = re.compile('ap_clk(.*)')
    declarations = []

    # Inputs
    for name, port in inputs.items():
        if clock_pattern.match(name):
            chisel_io = "Input(Clock())"
        else:
            chisel_io = generate_vals('Input', port['width'])
        declarations.append("val {} = {}".format(name, chisel_io))

    # Outputs
    declarations.extend(
        "val {} = {}".format(name, generate_vals('Output', port['width']))
        for name, port in outputs.items()
    )
    return declarations
|
|
|
|
|
|
def generate_opt_ap_signals(inputs, outputs):
    """Return the Scala wiring for the optional ap_* ports that are present.

    inputs: dict keyed by input port name
    outputs: dict keyed by output port name

    Returns: the concatenated assignment lines for 'ap_return' (output) and
             'ap_rst', 'ap_clk', 'ap_rst_n' (inputs), in that order, each
             included only when the corresponding port exists.
    """
    optional_wiring = (
        (outputs, 'ap_return', " io.ap.rtn := bb.io.ap_return\n"),
        (inputs, 'ap_rst', " bb.io.ap_rst := reset\n"),
        (inputs, 'ap_clk', " bb.io.ap_clk := clock\n"),
        (inputs, 'ap_rst_n', " bb.io.ap_rst_n := !reset.asBool()\n"),
    )
    return "".join(
        statement
        for ports, port_name, statement in optional_wiring
        if port_name in ports
    )
|
|
|
|
|
|
def generate_rocc_assignment(input_info):
    """Return the Scala wiring between the BlackBox ports and the RoCC IOs.

    Scalar arguments are connected to consecutive ``io.scalar_io`` slots and
    pointer arguments to consecutive ``io.ap_bus`` slots; the two kinds are
    numbered independently, in declaration order.

    input_info: dict = { arg name: {'type': 'scalar' or 'pointer', ...}}

    Returns: the concatenated assignment text for all arguments.
    """
    templates = {
        'scalar': Template(" bb.io.${ARG} := io.scalar_io($IDX)\n"),
        'pointer': Template(""" io.ap_bus($IDX).req.din := bb.io.${ARG}_req_din
bb.io.${ARG}_req_full_n := io.ap_bus($IDX).req_full_n
io.ap_bus($IDX).req_write := bb.io.${ARG}_req_write
bb.io.${ARG}_rsp_empty_n := io.ap_bus($IDX).rsp_empty_n
io.ap_bus($IDX).rsp_read := bb.io.${ARG}_rsp_read
io.ap_bus($IDX).req.address := bb.io.${ARG}_address
bb.io.${ARG}_datain := io.ap_bus($IDX).rsp.datain
io.ap_bus($IDX).req.dataout := bb.io.${ARG}_dataout
io.ap_bus($IDX).req.size := bb.io.${ARG}_size
"""),
    }
    # Next free slot per argument kind.
    next_slot = {'scalar': 0, 'pointer': 0}

    pieces = []
    for arg_name, arg in input_info.items():
        kind = arg['type']
        if kind not in templates:
            continue
        pieces.append(templates[kind].substitute(ARG=arg_name, IDX=next_slot[kind]))
        next_slot[kind] += 1
    return "".join(pieces)
|
|
|
|
|
|
def parse_verilog_arg_line(line, reArg, args):
    """ Parse a Verilog line for args.

    line: Input line string
    reArg: compiled regex for the arg; group 1 must capture the declaration
           that follows the direction keyword (optional ``[msb:lsb]`` range
           and the port name)
    args: dict = { arg name: {'width': arg bitwidth}}; updated in place

    The width is ``msb - lsb + 1`` for a numeric range and 1 when no range
    is given.  When the msb is written as ``<NAME> - 1`` with a non-numeric
    NAME, the width is stored as the string NAME.

    Returns: True if the line matched and an entry was stored in args,
             False otherwise.
    """
    arg_match = reArg.match(line)
    if not arg_match:
        return False
    declaration = arg_match.group(1)
    if not declaration:
        return False

    name = declaration
    msb, lsb = 0, 0
    symbolic_width = None
    range_match = re.match(r'\[(.*):(.*)\]\s*(.*)', declaration)
    if range_match:
        msb, lsb, name = range_match.groups()
        minus_one = re.match(r"(\S+) - 1", msb)
        if minus_one:
            size = minus_one.group(1)
            if re.fullmatch(r"\d+", size):
                # "<int> - 1": fold the arithmetic.  The captured size is a
                # string, so it must be converted before subtracting.
                msb = int(size) - 1
            else:
                # "<PARAM> - 1": keep the parameter name as the width.
                symbolic_width = size

    if symbolic_width is None:
        width = int(msb) - int(lsb) + 1
    else:
        width = symbolic_width
    # Whitespace before the ';' is captured with the name; drop it.
    args[name.strip()] = {'width': width}
    return True
|
|
|
|
|
|
def parse_verilog_rocc(vpath):
    """ Parse a centrifuge-generated verilog file to extract the information
    needed to generate a RoCC wrapper.

    vpath: Path to main verilog function file

    Returns: (inputs, outputs)
        inputs - OrderedDict = { input port name: {'width': bitwidth}}
        outputs - OrderedDict = { output port name: {'width': bitwidth}}
    """
    # Input/Output statements in the verilog. We assume only one module in the file.
    # Raw strings: '\s' is an invalid escape sequence in a normal literal.
    re_input = re.compile(r'^\s*input\s+(.*)\s*;')
    re_output = re.compile(r'^\s*output\s+(.*)\s*;')

    inputs = collections.OrderedDict()
    outputs = collections.OrderedDict()

    logger.info("Parsing: {}".format(vpath))
    with open(vpath, 'r') as vf:
        for line in vf:
            # test if it is input
            if not parse_verilog_arg_line(line, re_input, inputs):
                # test if it is output
                parse_verilog_arg_line(line, re_output, outputs)

    logger.info("Inputs: {}".format(inputs))
    logger.info("Outputs: {}".format(outputs))
    return (inputs, outputs)
|
|
|
|
|
|
def generate_chisel_rocc(func, idx, inputs, outputs, scala_dir, template_dir):
    """Generate the Chisel sources wrapping a function as a RoCC accelerator.

    Writes ``<func>_blackbox.scala`` and ``<func>_accel.scala`` into
    scala_dir from their templates, then copies three fixed support files
    (ap_bus, memControllerComponents, controlUtils) next to them.

    func: name of the function; used for the file names and the FUNC field
    idx: accepted for interface compatibility; not used by this function
    inputs: dict = { input port name: {'width': bitwidth}}
    outputs: dict = { output port name: {'width': bitwidth}}
    scala_dir: output directory (path-like supporting '/')
    template_dir: directory holding the templates (path-like supporting '/')
    """
    ##########################################################
    logger.info("Generating RoCC BlackBox file ...")
    blackbox_template = template_dir / 'chisel_rocc_blackbox_scala_template'

    # Generate arguments
    args_str = "\n ".join(generate_args(inputs, outputs))

    # Generate spec for args
    input_info = parse_verilog_input_info(inputs)
    info_dict = generate_rocc_input_info(input_info)
    if 'ap_return' in outputs:
        return_width = outputs['ap_return']['width']
    else:
        return_width = 1
    num_scalar = get_rocc_scalarIO_count(input_info)

    blackbox_fields = {"FUNC": func, "ARGS": args_str}
    for key in ('SCALAR_DATA_WIDTH_ARR', 'SCALAR_IDX_ARR', 'PTR_ADDR_WIDTH_ARR',
                'PTR_DATA_WIDTH_ARR', 'PTR_IDX_ARR'):
        blackbox_fields[key] = info_dict[key]
    blackbox_fields['RETURN_WIDTH'] = return_width
    blackbox_fields['SCALAR_IO'] = generate_rocc_scalarIO(num_scalar)
    # Signal assignments; the first covers the optional vivado ap signals
    blackbox_fields['AP_RETURN_RST_CLK'] = generate_opt_ap_signals(inputs, outputs)
    blackbox_fields['SIGNAL_ASSIGNMENT'] = generate_rocc_assignment(input_info)

    blackbox_path = scala_dir / pathlib.Path(func + '_blackbox.scala')
    util.generate_file(blackbox_template, blackbox_fields, blackbox_path)
    logger.info("\t\tGenerate rocc_blackbox code in CHISEL: {}".format(blackbox_path))

    ##########################################################
    logger.info("Generating RoCC Control file ...")
    accel_template = template_dir / 'chisel_rocc_accel_scala_template'

    accel_fields = {
        "FUNC": func,
        'SCALAR_IO_ASSIGNMENT0': generate_rocc_scalarIO_stmt0(input_info),
        'AP_RETURN_ASSIGNMENT': generate_rocc_ap_return_stmt(outputs),
        'SCALAR_IO_ASSIGNMENT1': generate_rocc_scalarIO_stmt1(num_scalar),
    }

    accel_path = scala_dir / pathlib.Path(func + '_accel.scala')
    util.generate_file(accel_template, accel_fields, accel_path)
    logger.info("\t\tGenerate rocc_accel code in CHISEL: {}".format(accel_path))

    ##########################################################
    # Support files copied verbatim: (log message, template name, output name)
    support_files = (
        ("Copying Vivado HLS Interface file ...",
         'ap_bus_scala_template', 'ap_bus.scala'),
        ("Copying ROCC Memory Controller file ...",
         'memControllerComponents_scala_template', 'memControllerComponents.scala'),
        ("Copying RoCC Controller Utilities file ...",
         'controlUtils_scala_template', 'controlUtils.scala'),
    )
    for message, template_name, scala_name in support_files:
        logger.info(message)
        shutil.copy(str(template_dir / template_name), str(scala_dir / scala_name))
|
|
|
|
|
|
def parse_verilog_tl(vpath):
    """Parse a centrifuge-generated verilog file to extract the information
    needed to generate tilelink wrappers.

    vpath: Path to the verilog file containing control signal info (path-like object)

    Returns: (inputs, outputs, params, buses)
        inputs - OrderedDict = { input port name: {'width': bitwidth}}
        outputs - OrderedDict = { output port name: {'width': bitwidth}}
        params - OrderedDict = { parameter name (starting with 'C_'): value
                 text as written in the file }
        buses - OrderedDict = { lower-cased bus name: {'width': value text}},
                one entry per ``C_M_AXI_<BUS>_DATA_WIDTH`` parameter
    """
    # Raw strings: '\s' and '\S' are invalid escapes in a normal literal.
    re_input = re.compile(r'^\s*input\s+(.*)\s*;')
    re_output = re.compile(r'^\s*output\s+(.*)\s*;')
    re_param = re.compile(r"parameter\s+(C_\S+) =\s+(.*);")
    re_bus_width = re.compile(r"C_M_AXI_(\S+)_DATA_WIDTH")

    inputs = collections.OrderedDict()
    outputs = collections.OrderedDict()
    params = collections.OrderedDict()
    buses = collections.OrderedDict()

    logger.info("Parsing: {}".format(vpath))
    with open(vpath, 'r') as vf:
        for line in vf:
            # test if it is input
            if parse_verilog_arg_line(line, re_input, inputs):
                continue
            # test if it is output
            if parse_verilog_arg_line(line, re_output, outputs):
                continue
            # otherwise it may be a parameter declaration
            param_match = re_param.match(line)
            if not param_match:
                continue
            param, width = param_match.group(1), param_match.group(2)
            params[param] = width
            bus_match = re_bus_width.match(param)
            if bus_match:
                buses[bus_match.group(1).lower()] = {'width': width}

    logger.info("Inputs: {}".format(inputs))
    logger.info("Outputs: {}".format(outputs))
    logger.info("Paramters: {}".format(params))
    logger.info("Buses: {}".format(buses))
    return (inputs, outputs, params, buses)
|
|
|
|
|
|
def generate_tl_assignment(buses):
    """Return the Scala declarations of one AXI4 master node per bus.

    buses: dict keyed by bus name; only the keys and their order are used

    Returns: the concatenated ``val node_<bus> = AXI4MasterNode(...)``
             declarations, where each master is named
             ``axil_hub_mem_out_<position of the bus>``.
    """
    node_template = Template("""
val node_${BUS_NAME} = AXI4MasterNode(Seq(AXI4MasterPortParameters(
masters = Seq(AXI4MasterParameters(
name = "axil_hub_mem_out_${IDX}",
id = IdRange(0, numInFlight),
aligned = true,
maxFlight = Some(8)
))
)
))
""")
    return "".join(
        node_template.substitute(BUS_NAME=bus_name, IDX=position)
        for position, bus_name in enumerate(buses)
    )
|
|
|
|
|
|
def generate_AXI_signal(matchInput, template):
    """Fill an assignment template from a matched AXI port name.

    matchInput: regex match whose group 1 is the bus name and group 2 the
                signal type of the port
    template: format string; ``{0}`` receives the bus name and ``{1}`` the
              lower-cased signal type

    Returns: the formatted string.
    """
    bus_name, signal_type = matchInput.group(1), matchInput.group(2)
    assert bus_name is not None
    assert signal_type is not None
    return template.format(bus_name, signal_type.lower())
|
|
|
|
|
|
def construct_axi_regex(regex):
    """Compile a master and a slave variant of an AXI port-name pattern.

    regex: pattern text to append after the one-letter side prefix

    Returns: dict = {'m': compiled 'm' + regex, 's': compiled 's' + regex}
    """
    return {side: re.compile(side + regex) for side in ('m', 's')}
|
|
|
|
|
|
# Ordered rules for AXI *input* ports of the BlackBox:
# (side prefix, channel alternatives, signal suffix, text appended after the
# port).  {0} is the bus name, {1} the lower-cased channel.  None means the
# port is recognised but deliberately left unconnected.
_TL_AXI_INPUT_SPECS = (
    ('m', 'AW|W|AR', 'READY', " := out_{0}.{1}.ready"),
    ('m', 'R|B', 'VALID', " := out_{0}.{1}.valid"),
    ('m', 'R', 'DATA', " := out_{0}.{1}.bits.data"),
    ('m', 'R', 'LAST', " := out_{0}.{1}.bits.last"),
    ('m', 'R|B', 'ID', " := out_{0}.{1}.bits.id"),
    ('m', 'R|B', 'USER', None),  # omit user signal
    ('m', 'R|B', 'RESP', " := out_{0}.{1}.bits.resp"),
    ('s', 'AW|W|AR', 'VALID', " := slave_in.{1}.valid"),
    ('s', 'AW|AR', 'ADDR', " := slave_in.{1}.bits.addr"),
    ('s', 'W', 'DATA', " := slave_in.{1}.bits.data"),
    ('s', 'W', 'STRB', " := slave_in.{1}.bits.strb"),
    ('s', 'R|B', 'READY', " := slave_in.{1}.ready"),
)

# Ordered rules for AXI *output* ports of the BlackBox; here the text is
# placed in front of the port.
_TL_AXI_OUTPUT_SPECS = (
    ('m', 'AW|W|AR', 'VALID', "out_{0}.{1}.valid := "),
    ('m', 'R|B', 'READY', "out_{0}.{1}.ready :="),
    ('m', 'AW|AR', 'ADDR', "out_{0}.{1}.bits.addr := "),
    ('m', 'AW|AR', 'ID', "out_{0}.{1}.bits.id := "),
    ('m', 'AW|AR', 'LEN', "out_{0}.{1}.bits.len := "),
    ('m', 'AW|AR', 'SIZE', "out_{0}.{1}.bits.size := "),
    ('m', 'AW|AR', 'BURST', "out_{0}.{1}.bits.burst := "),
    ('m', 'AW|AR', 'LOCK', "out_{0}.{1}.bits.lock := "),
    ('m', 'AW|AR', 'CACHE', "out_{0}.{1}.bits.cache := "),
    ('m', 'AW|AR', 'PROT', "out_{0}.{1}.bits.prot := "),
    ('m', 'AW|AR', 'QOS', "out_{0}.{1}.bits.qos := "),
    # emitted as a Scala comment, i.e. not connected
    ('m', 'AW|AR', 'REGION', "//out_{0}.{1}.bits.region := "),
    ('m', 'AW|AR', 'USER', None),
    ('m', 'W', 'DATA', "out_{0}.{1}.bits.data := "),
    ('m', 'W', 'STRB', "out_{0}.{1}.bits.strb := "),
    ('m', 'W', 'LAST', "out_{0}.{1}.bits.last := "),
    # TODO check if this is needed; no such signal in TLtoAXI4
    ('m', 'W', 'ID', None),
    ('m', 'W', 'USER', None),
    ('s', 'AW|W|AR', 'READY', "slave_in.{1}.ready := "),
    ('s', 'R|B', 'VALID', "slave_in.{1}.valid := "),
    ('s', 'R', 'DATA', "slave_in.{1}.bits.data := "),
    ('s', 'R|B', 'RESP', "slave_in.{1}.bits.resp := "),
)


def _compile_axi_rules(specs):
    """Turn rule specs into an ordered list of (compiled regex, template)."""
    return [
        (re.compile(side + r'_axi_(.*)_(' + channels + ')' + suffix + '$'), template)
        for side, channels, suffix, template in specs
    ]


def _axi_fragment(port, rules):
    """Return the formatted template of the first rule matching port.

    Returns None when the matching rule leaves the port unconnected.
    Raises AssertionError when no rule matches.
    """
    for pattern, template in rules:
        found = pattern.match(port)
        if found:
            if template is None:
                return None
            return template.format(found.group(1), found.group(2).lower())
    # Raised explicitly (the type is kept from the former ``assert``) so the
    # check is not stripped when Python runs with -O.
    raise AssertionError("Unrecognized AXI signal: {}".format(port))


def generate_tl_module_stmt(inputs, outputs, buses):
    """Return the Scala body that wires the BlackBox into the TL module.

    The text consists of, in order: one ``val (out_<bus>, edge_<bus>)``
    line per bus, the optional ap_* wiring from generate_opt_ap_signals,
    one assignment per AXI port (inputs first, then outputs; ports that are
    recognised but intentionally unconnected contribute an empty line), and
    a ``val ap_return`` line when the module has an 'ap_return' output.
    Ports that do not start with ``m_axi`` or ``s_axi`` are skipped.

    inputs: dict keyed by input port name
    outputs: dict keyed by output port name
    buses: dict keyed by bus name

    Raises: AssertionError if an ``m_axi``/``s_axi`` port matches no known
            AXI signal.
    """
    bus_stmts = [
        "val (out_{0}, edge_{0}) = outer.node_{0}.out(0)".format(bus_name)
        for bus_name in buses
    ]
    ret_str = "\n ".join(bus_stmts) + "\n"
    ret_str += generate_opt_ap_signals(inputs, outputs) + "\n"

    re_axi = re.compile(r'^(m_axi|s_axi)\S+$')
    input_rules = _compile_axi_rules(_TL_AXI_INPUT_SPECS)
    output_rules = _compile_axi_rules(_TL_AXI_OUTPUT_SPECS)

    assignments = []
    for port in inputs:
        if re_axi.match(port):
            driver = _axi_fragment(port, input_rules)
            assignments.append("" if driver is None else "bb.io." + port + driver)
    for port in outputs:
        if re_axi.match(port):
            sink = _axi_fragment(port, output_rules)
            assignments.append("" if sink is None else sink + "bb.io." + port)

    ret_str += " "
    ret_str += "\n ".join(assignments)

    # add return stmt if any
    if 'ap_return' in outputs:
        ret_str += "\n val ap_return = accel.io.ap.rtn\n"

    return ret_str
|
|
|
|
|
|
def generate_tl_trait_stmt(func, buses):
    """Return the Scala statements coupling each bus node onto the sbus.

    func: name of the function; selects the ``hls_<func>_accel`` instance
    buses: dict = { bus name: {'width': data width of the bus}}

    Returns: one ``sbus.coupleFrom`` statement per bus, ending in
             ``hls_<func>_accel.node_<bus>``, or "" when buses is empty.
    """
    # NOTE(review): the last stage used to be the literal
    # "hls_tl0_vadd_tl_vadd_accel.node_gmem0", ignoring FUNC and BUS_NAME.
    # The "hls_<func>_accel" instance naming is derived from that literal;
    # confirm it against the accel template.
    template = Template("""
sbus.coupleFrom(s"port_named_$$axi_m_portName") {
( _
:= TLBuffer(BufferParams.default)
:= TLFIFOFixer(TLFIFOFixer.all)
:= TLWidthWidget(${M_AXI_DATA_WIDTH} >> 3)
:= AXI4ToTL()
:= AXI4UserYanker(Some(8))
:= AXI4Fragmenter()
:= AXI4IdIndexer(1)
:= hls_${FUNC}_accel.node_${BUS_NAME}
)
}
""")
    return "".join(
        template.substitute(FUNC=func, BUS_NAME=bus_name,
                            M_AXI_DATA_WIDTH=bus['width'])
        for bus_name, bus in buses.items()
    )
|
|
|
|
|
|
def generate_chisel_tl(func, idx, inputs, outputs, params, buses, scala_dir, template_dir):
    """Generate the Chisel sources wrapping a function as a TileLink device.

    Writes ``<func>_blackbox.scala`` and ``<func>_accel.scala`` into
    scala_dir from their templates.

    func: name of the function; used for the file names and the FUNC field
    idx: value substituted for BASE_ADDR in the accel template
    inputs: dict = { input port name: {'width': bitwidth}}
    outputs: dict = { output port name: {'width': bitwidth}}
    params: dict = { parameter name: value }; must contain
            'C_S_AXI_DATA_WIDTH'
    buses: dict = { bus name: {'width': data width}}; when empty, a
           'gmem_dummy' entry of width 32 is added to it in place
    scala_dir: output directory (path-like supporting '/')
    template_dir: directory holding the templates (path-like supporting '/')
    """
    ##########################################################
    logger.info("Generating TL BlackBox file ...")
    blackbox_template = template_dir / 'chisel_tl_blackbox_scala_template'

    blackbox_fields = {
        "FUNC": func,
        # Generate parameters
        "PARAMS": "\n ".join(generate_params(params)),
        # Generate arguments
        "ARGS": "\n ".join(generate_args(inputs, outputs)),
    }
    blackbox_path = scala_dir / pathlib.Path(func + '_blackbox.scala')
    util.generate_file(blackbox_template, blackbox_fields, blackbox_path)
    logger.info("\t\tGenerate tl_blackbox code in CHISEL: {}".format(blackbox_path))

    ##########################################################
    logger.info("Generating TL Control file ...")
    accel_template = template_dir / 'chisel_tl_accel_scala_template'

    # Add dummy bus
    if len(buses) == 0:
        buses['gmem_dummy'] = {'width': 32}

    master_stmts = generate_tl_assignment(buses)
    module_stmts = generate_tl_module_stmt(inputs, outputs, buses)
    # TODO test multi bundles
    trait_stmts = generate_tl_trait_stmt(func, buses)

    accel_fields = {
        "FUNC": func,
        "BASE_ADDR": idx,
        "S_AXI_DATA_WIDTH": params['C_S_AXI_DATA_WIDTH'],
        "AXI_MASTER_STMT": master_stmts,
        "AXI_MODULE_STMT": module_stmts,
        "AXI_TRAIT_STMT": trait_stmts,
    }
    accel_path = scala_dir / pathlib.Path(func + '_accel.scala')
    util.generate_file(accel_template, accel_fields, accel_path)
    logger.info("\t\tGenerate tl_accel code in CHISEL: {}".format(accel_path))
|
|
|
|
def generate_chisel(accel_conf):
    """Generate the Chisel wrappers for every accelerator in accel_conf.

    For each entry of ``accel_conf.rocc_accels`` and then each entry of
    ``accel_conf.tl_accels``, parses ``<verilog_dir>/<name>.v`` and writes
    the matching RoCC or TileLink wrapper into the accelerator's scala_dir.
    The template directory is read from the 'template-dir' option.
    """
    from .. import util
    template_dir = util.getOpt('template-dir')

    for accel in accel_conf.rocc_accels:
        logger.info("\tRun CHISEL generation for {}:".format(accel.name))
        verilog_path = accel.verilog_dir / (accel.name + ".v")
        inputs, outputs = parse_verilog_rocc(verilog_path)
        generate_chisel_rocc(accel.name, accel.rocc_insn_id, inputs, outputs,
                             accel.scala_dir, template_dir)

    for accel in accel_conf.tl_accels:
        logger.info("\tRun CHISEL generation for {}:".format(accel.name))
        verilog_path = accel.verilog_dir / (accel.name + ".v")
        inputs, outputs, params, buses = parse_verilog_tl(verilog_path)
        generate_chisel_tl(accel.name, accel.base_addr, inputs, outputs,
                           params, buses, accel.scala_dir, template_dir)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: generate the wrappers for a single function.
    parser = argparse.ArgumentParser(
        description="Generate Chisel wrappers for a given centrifuge-generated function.")

    parser.add_argument(
        '-f', '--func', required=True,
        help="Name of function to accelerate")
    parser.add_argument(
        '-b', '--base', required=True,
        help="Base address of function (if tilelink), RoCC index (if rocc)")
    parser.add_argument(
        '-p', '--prefix', default="",
        help="Optional prefix for function")
    parser.add_argument(
        '-m', '--mode', required=True,
        help="Function integration mode (either 'tl' or 'rocc')")
    parser.add_argument(
        '-s', '--source', required=True, type=pathlib.Path,
        help="Path to the source directory to use when generating (e.g. 'centrifuge/accel/hls_example_func/').")

    args = parser.parse_args()
    main_dir = args.source / 'src' / 'main'
    scala_dir = main_dir / 'scala'
    scala_dir.mkdir(exist_ok=True)

    # Make the util module importable from the parent directory.
    import sys
    sys.path.append("..")
    import util

    module_name = pathlib.Path(__file__).stem
    util.setup_logging(module_name, logger)

    ctx = util.initConfig()
    template_dir = util.getOpt('template-dir')

    # The verilog file is named after the prefixed function.
    full_name = args.prefix + args.func
    verilog_path = main_dir / 'verilog' / (full_name + ".v")

    if args.mode == 'tl':
        inputs, outputs, params, buses = parse_verilog_tl(verilog_path)
        generate_chisel_tl(full_name, args.base, inputs, outputs, params, buses,
                           scala_dir, template_dir)
    elif args.mode == 'rocc':
        inputs, outputs = parse_verilog_rocc(verilog_path)
        generate_chisel_rocc(full_name, args.base, inputs, outputs,
                             scala_dir, template_dir)
    else:
        raise NotImplementedError("Mode '{}' not supported.".format(args.mode))
|
|
|