# DeepBurning-MixQ/cifar/export_hls.py

import argparse
import time
from typing import Dict, List

import torch
import numpy as np
import sys
import os

sys.path.append('..')
import models
from utils.view_pt import select_weight_file
from anypacking.quant_module import HWGQ, QuantConv2d, ImageInputQ, QuantLinear


class ConvParam: ...


def write_hls_config(model_param, path):
    name_mapping = {
        'k': 'K',
        # 's': 'S',
        # 'p': 'P',
        'ich': 'IFM_CH',
        'irow': 'IFM_ROW',
        'icol': 'IFM_COL',
        'och': 'OFM_CH',
        'orow': 'OFM_ROW',
        'ocol': 'OFM_COL',
        'abit': 'IN_BIT',
        'wbit': 'W_BIT',
        'incbit': 'INC_BIT',
        'biasbit': 'BIAS_BIT',
        'simd': 'SIMD',
        'pe': 'PE',
        'lshift': 'L_SHIFT'
    }
    content = f'''/********************************************************************************
* Filename: config.h
* Date: {time.ctime()}
* Description: This file is generated by {parser.prog}
* ptfilename: {opt.weight}
********************************************************************************/
#ifndef _CONFIG_H_
#define _CONFIG_H_
'''
    for n, conv_param in enumerate(model_param):
        content += f'// {conv_param.type}_{n}\n'
        for k, v in name_mapping.items():
            if hasattr(conv_param, k):  # e.g. conv_last has no incbit
                content += f'#define {conv_param.type.upper()}_{n}_{v} {getattr(conv_param, k)}\n'
        content += '\n'
    content += '#endif'
    with open(path + 'config.h', 'w') as f:
        print(content, file=f)
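
# For reference, write_hls_config emits #define lines of the form below
# (values illustrative, assuming a 3x3 conv with 3 input channels as layer 0):
#   // conv_0
#   #define CONV_0_K 3
#   #define CONV_0_IFM_CH 3
#   #define CONV_0_IN_BIT 8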


def extract_model(in_shape):
    model_param: List[ConvParam] = []
    feature_map_shape = in_shape
    conv_cnt = 0
    conv_cur = None
    for sub_module in model.modules():
        # Expected layer order: [QAct] -> [Pooling] -> Conv -> [BN] -> [Pooling],
        # consumed in state-machine fashion.
        if isinstance(sub_module, (HWGQ, ImageInputQ)):
            print(' Detected ActQ Layer', end='')
            if conv_cur is None: conv_cur = ConvParam()
            conv_cur.abit = sub_module.bit
            conv_cur.astep = sub_module.step
            conv_cur.actq_class = type(sub_module).__name__
            print(f', abit {conv_cur.abit}, astep {conv_cur.astep}, class {conv_cur.actq_class}')
            if conv_cnt:  # previous layer's output quant equals current activation quant
                model_param[conv_cnt-1].obit = conv_cur.abit
                model_param[conv_cnt-1].ostep = conv_cur.astep
        elif isinstance(sub_module, torch.nn.Conv2d):
            if conv_cur is None: conv_cur = ConvParam()
            conv_cur.n = conv_cnt
            print('Extract conv_%d' % conv_cnt, end='')
            conv_cur.k = sub_module.kernel_size[0]
            conv_cur.s = sub_module.stride[0]
            conv_cur.p = sub_module.padding[0]
            conv_cur.ich = sub_module.in_channels
            conv_cur.och = sub_module.out_channels
            conv_cur.irow = feature_map_shape[1]
            conv_cur.icol = feature_map_shape[2]
            feature_map_shape[0] = sub_module.out_channels
            feature_map_shape[1] = (feature_map_shape[1] + 2 * sub_module.padding[0] - sub_module.kernel_size[0]) // sub_module.stride[0] + 1
            feature_map_shape[2] = (feature_map_shape[2] + 2 * sub_module.padding[0] - sub_module.kernel_size[0]) // sub_module.stride[0] + 1
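            # The two updates above apply the standard output-size formula
            # o = (i + 2*p - k) // s + 1 (square kernel/stride/padding, per the [0] indexing).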
            conv_cur.orow = feature_map_shape[1]
            conv_cur.ocol = feature_map_shape[2]
            assert sub_module.bias is None, 'inner conv has no bias in this model'
            if isinstance(sub_module, QuantConv2d):  # new-style quant conv
                conv_cur.wbit = sub_module.bit
                conv_cur.w, conv_cur.wstep = sub_module.export_quant()  # wstep is not QuantConv2d.step because of alpha
            else:
                raise NotImplementedError(sub_module)
            print(', ich {ich}, och {och}, irow {irow}, icol {icol}, ksp {k}{s}{p}, wbit {wbit}, wstep {wstep}'.format(**vars(conv_cur)))
            conv_cur.type = 'conv'
            model_param.append(conv_cur)
            conv_cur = None
            conv_cnt += 1
        elif isinstance(sub_module, torch.nn.Linear):
            if conv_cur is None: conv_cur = ConvParam()  # TODO: independent type for linear layer
            conv_cur.n = conv_cnt
            print('Extract layer %d (linear layer)' % conv_cnt, end='')
            conv_cur.ich = sub_module.in_features
            conv_cur.och = sub_module.out_features
            conv_cur.irow = feature_map_shape[1]
            conv_cur.icol = feature_map_shape[2]
            if sub_module.bias is not None:
                conv_cur.convbias = sub_module.bias.detach().numpy()
                print(', +bias', end='')
            if isinstance(sub_module, QuantLinear):  # new-style quant linear
                conv_cur.wbit = sub_module.bit
                conv_cur.w, conv_cur.wstep = sub_module.export_quant()  # wstep is not QuantLinear.step because of alpha
            print(', ich {ich}, och {och}, wbit {wbit}, wstep {wstep}'.format(**vars(conv_cur)))
            conv_cur.type = 'linear'
            model_param.append(conv_cur)
            conv_cur = None
            conv_cnt += 1
        elif isinstance(sub_module, torch.nn.BatchNorm2d):
            print(' Detected BatchNorm2d')
            gamma = sub_module.weight
            beta = sub_module.bias
            mean = sub_module.running_mean
            var = sub_module.running_var
            eps = sub_module.eps
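            # Fold BN into a per-channel affine y = bn_w * x + bn_b attached to the
            # preceding conv: bn_w = gamma/sqrt(var+eps), bn_b = beta - mean*bn_w.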
            model_param[-1].bn_w = (gamma / torch.sqrt(var + eps)).detach().numpy()
            model_param[-1].bn_b = (beta - mean / torch.sqrt(var + eps) * gamma).detach().numpy()
        elif isinstance(sub_module, torch.nn.MaxPool2d):
            print(' Detected MaxPool2d')
            feature_map_shape[1] = feature_map_shape[1] // sub_module.kernel_size
            feature_map_shape[2] = feature_map_shape[2] // sub_module.kernel_size
    assert hasattr(model_param[0], 'abit')
    return model_param


def process_batchnorm(model_param):
    '''process_batchnorm(model_param)
    Merge the wstep, astep, ostep scales into batchnorm, then quantize.
    Method:
        Define MAC = Conv(w, a), out = MAC*BN_w + BN_b,
        wq = w/wstep, aq = a/astep, MACq = MAC/MACstep, outq = out/ostep.
        outq = (MAC*BN_w + BN_b) / ostep
             = MACq * (MACstep/ostep)*BN_w + BN_b/ostep
             = MACq * inc_raw + bias_raw
        The next layer's activation is a' = ActQ(out), i.e. a'q = clip(round(outq)).
    Quantization of inc_raw & bias_raw:
        outq_real = round((MACq*round(inc_raw*scale) + round(bias_raw*scale)) / scale)   ; where scale = 2**T
                  = (MACq*round(inc_raw*scale) + round(bias_raw*scale) + 0.5*scale) // scale   ; floor division
                  = (MACq*inc + bias + 2**(T-1)) >> T   ; [!] the 2**(T-1) rounding offset is added by the HLS code
    Params:
        T = (wbit-1) + abit + lshift   # inherited from DoReFa quant, not optimal
        MBIT = wbit + abit + ceil(log2(sum_number))
        incbit = len(bit(inc)); biasbit = len(bit(bias))
        A larger lshift is more accurate, but MBIT + incbit must stay below 48.
    '''
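    # Worked example (illustrative numbers): with lshift=16, wbit=4, abit=4,
    # T = 16 + 4 + 4 - 1 = 23; inc_raw = 0.013 quantizes to
    # inc = round(0.013 * 2**23) = 109052 (incbit = 18), and the HLS side
    # computes (MACq*inc + bias + 2**22) >> 23.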
    # Bit length needed for a signed representation: sign bit + magnitude bits.
    # (Defined before the loop since it is also used for conv_last below.)
    bitlength = lambda x: 1 + int(np.abs(x).max()).bit_length()
    lshift = 16
    for conv in model_param[:-1]:
        print(f'Process bn_{conv.n}, shape {conv.bn_w.shape},', end=' ')
        # Merge quantization steps into BN
        conv.lshift = lshift
        MACstep = conv.wstep * conv.astep
        ostep = conv.ostep
        inc_raw = conv.bn_w * MACstep / ostep
        bias_raw = conv.bn_b / ostep
        conv.inc_raw = inc_raw
        conv.bias_raw = bias_raw
        # Quantization
        T = lshift + conv.wbit + conv.abit - 1
        conv.inc = np.round(inc_raw * 2**T).astype(np.int64)
        conv.bias = np.round(bias_raw * 2**T).astype(np.int64)
        conv.lshift_T = T
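        # Get bit lengths; e.g. bitlength(np.array([-5, 3])) == 1 + (5).bit_length() == 4
        # (sign bit + 3 magnitude bits).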
        conv.incbit = bitlength(conv.inc)
        conv.biasbit = bitlength(conv.bias)
        print(f'incbit {conv.incbit}, biasbit {conv.biasbit}, lshift_T {conv.lshift_T}')
    conv_last = model_param[-1]  # process the last layer's bias
    conv_last.inc = None
    conv_last.div = 1 / (conv_last.wstep * conv_last.astep)
    conv_last.bias = np.round(conv_last.convbias * conv_last.div).astype(np.int64)
    conv_last.biasbit = bitlength(conv_last.bias)
    print(f'conv_last biasbit {conv_last.biasbit}, div {conv_last.div}')


def reorder_weight(model_param, layers_simd, layers_pe):
    '''reorder_weight(model_param)
    Reorder weight/inc/bias arrays into the layout expected by the HLS code.
    '''
    for conv in model_param:
        if conv.type == 'linear':  # new reorder scheme for the linear layer
            pe_l = 1
            simd_l = 1
            in_pe_l = 8
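            # NOTE: the shapes below appear hard-coded for this repo's CIFAR-10
            # model: 10 output classes and a 256-channel 4x4 input feature map.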
            w = conv.w.reshape(10, -1, 4, 4)
            w = w.reshape(10 // (2 * pe_l), pe_l, 2, 256 // in_pe_l, in_pe_l // simd_l, simd_l, 4, 4)  # [OUT_CH/2PE, PE, 2, IN_CH/IN_PE, IN_PE/SIMD, SIMD, H, W]
            w = w.transpose(1, 6, 3, 7, 0, 4, 5, 2)  # [PE, H, IN_CH/IN_PE, W, OUT_CH/2PE, IN_PE/SIMD, SIMD, 2]
            w = w.reshape(*w.shape[:6], -1)  # [PE, H, IN_CH/IN_PE, W, OUT_CH/2PE, IN_PE/SIMD, SIMD*2]
            print(w.shape)
            conv.w = w
            continue
        print(f'Reorder conv_{conv.n}, w {conv.w.shape}', end='')
        conv.simd = layers_simd[conv.n]
        conv.pe = layers_pe[conv.n]
        # reorder batchnorm inc/bias
        if conv.inc is not None:
            conv.inc = conv.inc.reshape(conv.och//conv.pe, conv.pe).T
        if conv.bias is not None:
            conv.bias = conv.bias.reshape(conv.och//conv.pe, conv.pe).T
        # reorder conv weight
        w = conv.w  # [och, ich, kr, kc]
        assert conv.och % conv.pe == 0, f"conv_{conv.n}, och {conv.och}, pe {conv.pe}"
        assert conv.k*conv.ich % conv.simd == 0, f"conv_{conv.n}, ich {conv.ich}, k {conv.k}, simd {conv.simd}"
        # if conv.n == 0:  # first layer is different
        #     w = w.transpose(0, 2, 3, 1)  # [och, kr, kc, ich]
        # else:
        w = w.transpose(0, 3, 2, 1)  # [och, kc, kr, ich]
        w = w.reshape(conv.och//conv.pe, conv.pe, conv.k, conv.k*conv.ich//conv.simd, conv.simd)
        w = w.transpose(1, 2, 0, 3, 4)  # [pe, k, och/pe, k*ich/simd, simd]
        w = w.reshape(conv.pe, conv.k, -1, conv.simd)  # HLS format [pe, k, och/pe*k*ich/simd, simd]
        if conv.k == 1:  # 1x1 kernel collapses the k dimension
            w = w.reshape(conv.pe, -1, conv.simd)
        print(' ->', w.shape)
        conv.w = w


def print_ndarray_recursion(arr, str_func=str, file=sys.stdout, stop=0):
    if not hasattr(arr, '__iter__') or len(arr.shape) == stop:
        print(str_func(arr), file=file, end='')
        return
    ends = '' if len(arr.shape) == stop+1 else '\n'
    print('{', file=file, end='')
    for i, item in enumerate(arr):
        print_ndarray_recursion(item, str_func, file, stop)
        if i != len(arr)-1: print(',', file=file, end=ends)
    print(ends+'}', file=file, end='')
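
# For reference: print_ndarray_recursion emits C-style brace initializers, e.g.
# a 2x2 integer array prints as
#   {{1,2},
#   {3,4}
#   }
# and stop=1 applies str_func to whole innermost rows (used for SIMD packing).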


def write_hls_linearlayer(layer, f):
    n = layer.n
    print(f"// layer: {n}, wbit: {layer.wbit}", file=f)
    hex_str = lambda x: '"' + hex(x) + '"'
    print(f"const ap_int<{layer.wbit}> linear_{n}_w[{layer.och}][{layer.ich}]=", file=f)
    print_ndarray_recursion(layer.w, hex_str, f)
    print(';', file=f)
    if layer.bias is not None:
        print(f"const ap_int<{layer.biasbit}> linear_{n}_bias[{layer.och}]=", file=f)
        print_ndarray_recursion(layer.bias, hex_str, f)
        print(';', file=f)


def write_hls_linearlayer_reorder(layer, d0, d1, d2, d3, d4, d5, d6, f):
    n = layer.n
    print(f"// layer: {n}, wbit: {layer.wbit}", file=f)
    hex_str = lambda x: '"' + hex(x) + '"'
    def pack1d_str(arr):  # arr: 1-d array of signed weights
        x = 0
        # print(arr.shape)
        for v in arr[::-1]:  # [!] reverse SIMD pack; must match the HLS implementation
            v = int(v)  # use Python big integers, not np.int
            assert -(1 << (layer.wbit-1)) <= v < (1 << (layer.wbit-1)), f'got v={v} while wbit={layer.wbit}'
            x = (x << layer.wbit) + (v & (2**layer.wbit - 1))
        return hex_str(x)
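    # e.g. with wbit=2, pack1d_str(np.array([1, -1, 0, 1])) walks the array in
    # reverse and packs two's-complement fields: 0b01_00_11_01 -> '"0x4d"'.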
print(f"const ap_uint<{layer.wbit * d6}> linear_{n}_w[{d0}][{d1}][{d2}][{d3}][{d4}][{d5}]=", file=f)
print_ndarray_recursion(layer.w, pack1d_str, f, stop=1)
print(';', file=f)
if layer.bias is not None:
print(f"const ap_int<{layer.biasbit}> linear_{n}_bias[{layer.och}]=", file=f)
print_ndarray_recursion(layer.bias, hex_str, f)
print(';', file=f)


def write_hls_weights(model_param, path):
    '''write_hls_weights(model_param, path)
    Write the HLS weight/inc/bias array source according to each layer's numpy shape.
    '''
    f = open(path + 'weights.hpp', 'w')
    print(f'''/********************************************************************************
* Filename: weights.hpp
* Date: {time.ctime()}
* Description: This file is generated by {parser.prog}
* ptfilename: {opt.weight}
********************************************************************************/
#ifndef _WEIGHTS_HPP_
#define _WEIGHTS_HPP_
#include <ap_int.h>
''', file=f)
    for conv in model_param:
        if conv.type == 'linear':
            # unpack the 7-d reordered shape produced by reorder_weight
            pe_pr, h_pr, inch_inpe_pr, w_pr, outch_2pe_pr, inpe_simd_pr, simd2_pr = conv.w.shape
            write_hls_linearlayer_reorder(conv, pe_pr, h_pr, inch_inpe_pr, w_pr, outch_2pe_pr, inpe_simd_pr, simd2_pr, f)
            continue
        n = conv.n
        print(f"Write conv_{n} weight, pe {conv.pe}, simd {conv.simd}, wbit {conv.wbit}")
        print(f"// layer: {n}, PE: {conv.pe}, SIMD: {conv.simd}, wbit: {conv.wbit}", file=f)
        # print conv weight, merging [SIMD] values into one ap_uint
        if conv.k > 1:
            print(f"const ap_uint<{conv.wbit * conv.simd}> conv_{n}_w[{conv.pe}][{conv.k}][{conv.w.shape[2]}]=", file=f)
        else:
            print(f"const ap_uint<{conv.wbit * conv.simd}> conv_{n}_w[{conv.pe}][{conv.w.shape[1]}]=", file=f)
        hex_str = lambda x: '"' + hex(x) + '"'
        def pack1d_str(arr):  # arr: 1-d array of signed weights
            x = 0
            for v in arr[::-1]:  # [!] reverse SIMD pack; must match the HLS implementation
                v = int(v)  # use Python big integers, not np.int
                assert -(1 << (conv.wbit-1)) <= v < (1 << (conv.wbit-1)), f'got v={v} while wbit={conv.wbit}'
                x = (x << conv.wbit) + (v & (2**conv.wbit - 1))
            return hex_str(x)
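        # Same packing as in write_hls_linearlayer_reorder: each SIMD group of
        # wbit-wide two's-complement weights becomes one ap_uint<wbit*SIMD> literal.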
        print_ndarray_recursion(conv.w, pack1d_str, f, stop=1)
        print(';', file=f)
        # print inc, bias
        if conv.inc is not None:
            print(f"const ap_int<{conv.incbit}> conv_{n}_inc[{conv.pe}][{conv.och//conv.pe}]=", file=f)
            print_ndarray_recursion(conv.inc, hex_str, f)
            print(';', file=f)
        if conv.bias is not None:
            print(f"const ap_int<{conv.biasbit}> conv_{n}_bias[{conv.pe}][{conv.och//conv.pe}]=", file=f)
            print_ndarray_recursion(conv.bias, hex_str, f)
            print(';', file=f)
    print('#endif', file=f)
    f.close()


def adjust_weight(model_param):
    # special_wa_bit = ((5,6), (7,3))
    special_wa_bit = ((4, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (7, 2), (7, 3))  # these (wbit, abit) packings can't quantize to -2**(wbit-1)
    for conv in model_param:
        if (conv.wbit, conv.abit) in special_wa_bit:
            print(f'Adjust conv_{conv.n} wbit={conv.wbit}')
            conv.w = np.maximum(conv.w, -2**(conv.wbit-1)+1)
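    # e.g. for wbit=4 the clamp keeps weights in the symmetric range [-7, 7],
    # dropping the lone -8 code that these packings cannot represent.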


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-w', '--weight', default=None, help='.pt file name in ./weights/')
    parser.add_argument('-m', '--model', default='VGG_tiny_FixQ', help='model class name in models.py')
    parser.add_argument('-c', '--config-simd-pe', default='config_simd_pe', help='.txt file in ./hls/')
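    # Example invocation (the weight file name is illustrative):
    #   python export_hls.py -w my_checkpoint -m VGG_tiny_FixQ -c config_simd_pe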
    opt = parser.parse_args()
    if opt.weight is None: opt.weight = select_weight_file()
    simd_pe = np.loadtxt('hls/' + opt.config_simd_pe + '.txt', dtype=int, skiprows=1)
    dir_output = 'hls/' + opt.weight + '/'
    if not os.path.exists(dir_output): os.makedirs(dir_output)

    # load model and state_dict
    ptfile: Dict = torch.load('weights/' + opt.weight + '.pt', map_location='cpu')
    model = getattr(models, opt.model)(**ptfile.setdefault('model_params', {}))
    model.load_state_dict(ptfile['model'], strict=False)

    # process
    model_param = extract_model([1, 32, 32])
    adjust_weight(model_param)
    process_batchnorm(model_param)  # get bn params before writing the hls config
    torch.save(model_param, dir_output + 'model_param.pkl')
    reorder_weight(model_param, simd_pe[:, 0], simd_pe[:, 1])  # get pe, simd params before writing the hls config
    write_hls_config(model_param, dir_output)
    write_hls_weights(model_param, dir_output)