add bypass model; new packing factor

fffasttime 2022-12-06 17:22:42 +08:00
parent 71737900d1
commit 469a4024b6
8 changed files with 371 additions and 25 deletions

View File

@@ -9,13 +9,13 @@ gaussian_steps = {1: 1.596, 2: 0.996, 3: 0.586, 4: 0.336, 5: 0.190, 6: 0.106, 7:
hwgq_steps = {1: 0.799, 2: 0.538, 3: 0.3217, 4: 0.185, 5: 0.104, 6: 0.058, 7: 0.033, 8: 0.019}
dsp_factors=[
[18,15,12,6,6,6,6],
[15,12,6,6,6,6,3],
[12,12,12,6,6,6,6],
[12,12,6,6,6,6,3],
[9,6,6,6,6,6,3],
[9,6,6,6,6,3,3],
[6,6,6,3,3,3,3],
[6,6,3,3,3,3,2],
[4.5,3,3,3,3,3,2],
[6,3,3,3,3,3,2],
]
class _gauss_quantize_sym(torch.autograd.Function):
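How this packing table is consumed: fetch_arch_info later in this commit estimates a layer's DSP usage as its MAC count divided by dsp_factors[wbit-2][abit-2], i.e. rows and columns start at 2 bits. A minimal sketch of that lookup (the helper name is ours, not the repo's):

    def estimated_dsps(size_product, wbit, abit, dsp_factors):
        # Rows/columns of the table start at 2 bits, hence the -2 offsets;
        # the packing factor is how many MACs share one DSP slice, so the
        # quotient is the layer's estimated DSP cost.
        return size_product / dsp_factors[wbit - 2][abit - 2]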

View File

@@ -113,6 +113,7 @@ if __name__ == '__main__':
parser.add_argument('-w', '--weights', default=None, help='weights path')
parser.add_argument('-e', '--epochs', type=int, default=40)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--bypass', action='store_true', help='use bypass model')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--lr', type=float, default=0.03)
parser.add_argument('--bitw', type=str, default='')
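Since the new flag uses action='store_true', existing commands are unaffected; a hypothetical opt-in parse against the parser defined above:

    # Hypothetical invocation; --bypass defaults to False unless given.
    opt = parser.parse_args(['--bypass', '--bitw', '444444444'])
    assert opt.bypass is True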

View File

@@ -305,7 +305,8 @@ def write_hls_weights(model_param, path):
f.close()
def adjust_weight(model_param):
special_wa_bit = ((5,6), (7,3)) # These packings can't quantize to -2**(wbit-1)
special_wa_bit = ((4,2),(5,3),(5,4),(5,5),(5,6),(5,7),(5,8),(7,2),(7,3))
# These packings can't quantize to -2**(wbit-1)
for conv in model_param:
if (conv.wbit, conv.abit) in special_wa_bit:
print(f'Adjust conv_{conv.n} wbit={conv.wbit}')
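The body of adjust_weight is cut off by this hunk; a minimal sketch of the kind of adjustment the comment implies, assuming the fix is to drop the unrepresentable minimum level (the helper name and exact rule are ours):

    import torch

    def clamp_min_level(qweight: torch.Tensor, wbit: int) -> torch.Tensor:
        # Assumed rule: these packings cannot represent -2**(wbit-1), so
        # the lowest quantized level is raised by one step.
        return qweight.clamp(min=-(2 ** (wbit - 1)) + 1)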

View File

@@ -59,10 +59,13 @@ def train():
test_path = localconfig.test_path
nc = 1
results_file = 'result_%s.txt'%opt.name
results_file = 'results/%s.txt'%opt.name
# Initialize model
model = UltraNet_FixQ(opt.bitw, opt.bita).to(device)
if opt.bypass:
model = UltraNetBypass_FixQ(opt.bitw, opt.bita).to(device)
else:
model = UltraNet_FixQ(opt.bitw, opt.bita).to(device)
# Optimizer
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
@@ -270,10 +273,10 @@ def train():
'extra': {'time': time.ctime(), 'name': opt.name}}
# Save last checkpoint
torch.save(chkpt, wdir + 'last_%s.pt'%opt.name)
torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
if test_iou == test_best_iou:
torch.save(chkpt, wdir + 'test_best_%s.pt'%opt.name)
torch.save(chkpt, wdir + '%s_best.pt'%opt.name)
# Delete checkpoint
del chkpt
@@ -292,6 +295,7 @@ def train():
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--bypass', action='store_true', help='use bypass model')
parser.add_argument('--epochs', type=int, default=200) # 500200 batches at bs 16, 117263 COCO images = 273 epochs
parser.add_argument('--batch-size', type=int, default=64) # effective bs = batch_size * accumulate = 16 * 4 = 64
parser.add_argument('--accumulate', type=int, default=1, help='batches to accumulate before optimizing')

View File

@@ -465,4 +465,318 @@ class UltraNetFloat(nn.Module):
else: # test
io, p = zip(*yolo_out) # inference output, training output
return torch.cat(io, 1), p
return x
class ReorgLayer(nn.Module):
def __init__(self, stride=2):
super(ReorgLayer, self).__init__()
self.stride = stride
def forward(self, x):
stride = self.stride
assert(x.data.dim() == 4)
B = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
assert(H % stride == 0)
assert(W % stride == 0)
ws = stride
hs = stride
x = x.view([B, C, H//hs, hs, W//ws, ws]).transpose(3, 4).contiguous()
x = x.view([B, C, H//hs*W//ws, hs*ws]).transpose(2, 3).contiguous()
x = x.view([B, C, hs*ws, H//hs, W//ws]).transpose(1, 2).contiguous()
x = x.view([B, hs*ws*C, H//hs, W//ws])
return x
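ReorgLayer is a space-to-depth transform: every stride x stride spatial block is folded into the channel axis, so (B, C, H, W) becomes (B, s*s*C, H/s, W/s). Note the output channel order is (block position, original channel), which differs from torch.nn.functional.pixel_unshuffle, so the two are not drop-in replacements. A quick shape check (input values are arbitrary):

    import torch

    x = torch.randn(2, 3, 4, 4)
    y = ReorgLayer(stride=2)(x)
    assert y.shape == (2, 3 * 4, 2, 2)  # (B, s*s*C, H/s, W/s)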
class UltraNetBypassFloat(nn.Module):
def __init__(self):
super(UltraNetBypassFloat, self).__init__()
self.reorg = ReorgLayer(stride=2)
self.layers_p1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, stride=2)
)
self.layers_p2 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.layers_p3 = nn.Sequential(
nn.MaxPool2d(2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.layers_p4 = nn.Sequential(
nn.Conv2d(320, 64, kernel_size=3, stride=1, padding=1, bias=False), # concat: reorg(p2) 64*4 + p3 64 = 320 channels
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 36, kernel_size=1, stride=1, padding=0)
)
self.yololayer = YOLOLayer([[20, 20], [20, 20], [20, 20], [20, 20], [20, 20], [20, 20]])
self.yolo_layers = [self.yololayer]
def forward(self, x):
img_size = x.shape[-2:]
yolo_out, out = [], []
x_p1 = self.layers_p1(x)
x_p2 = self.layers_p2(x_p1)
x_p2_reorg = self.reorg(x_p2)
x_p3 = self.layers_p3(x_p2)
x_p4_in = torch.cat([x_p2_reorg, x_p3], 1)
x_p4 = self.layers_p4(x_p4_in)
x = self.yololayer(x_p4, img_size)
yolo_out.append(x)
if self.training: # train
return yolo_out
else: # test
io, p = zip(*yolo_out) # inference output, training output
return torch.cat(io, 1), p
return x
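The bypass wiring explains the 320-channel input of layers_p4: reorg quadruples p2's 64 channels while halving its resolution to match p3, and the two branches are concatenated. A shape walkthrough with an illustrative 160x320 input:

    import torch

    model = UltraNetBypassFloat().eval()
    x = torch.randn(1, 3, 160, 320)   # input size is illustrative
    p1 = model.layers_p1(x)           # (1, 64, 20, 40) after three pools
    p2 = model.layers_p2(p1)          # (1, 64, 20, 40), same resolution
    p3 = model.layers_p3(p2)          # (1, 64, 10, 20) after one more pool
    assert model.reorg(p2).shape == (1, 256, 10, 20)  # 64*4 channels
    assert torch.cat([model.reorg(p2), p3], 1).shape[1] == 320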
class UltraNetBypass_MixQ(nn.Module):
def __init__(self, share_weight = False):
super(UltraNetBypass_MixQ, self).__init__()
self.reorg = ReorgLayer(stride=2)
self.conv_func = qm.MixActivConv2d
conv_func = self.conv_func
conv_kwargs = {'kernel_size':3, 'stride':1, 'padding':1, 'bias':False}
qspace = {'wbits':[2,3,4,5,6,7,8], 'abits':[2,3,4,5,6,7,8], 'share_weight':share_weight}
self.layers_p1 = nn.Sequential(
conv_func(3, 16, ActQ = qm.ImageInputQ, **conv_kwargs, **qspace),
nn.BatchNorm2d(16),
nn.MaxPool2d(2, stride=2),
conv_func(16, 32, **conv_kwargs, **qspace),
nn.BatchNorm2d(32),
nn.MaxPool2d(2, stride=2),
conv_func(32, 64, **conv_kwargs, **qspace),
nn.BatchNorm2d(64),
nn.MaxPool2d(2, stride=2)
)
self.layers_p2 = nn.Sequential(
conv_func(64, 64, **conv_kwargs, **qspace),
nn.BatchNorm2d(64),
)
self.layers_p3 = nn.Sequential(
nn.MaxPool2d(2, stride=2),
conv_func(64, 64, **conv_kwargs, **qspace),
nn.BatchNorm2d(64),
conv_func(64, 64, **conv_kwargs, **qspace),
nn.BatchNorm2d(64),
conv_func(64, 64, **conv_kwargs, **qspace),
nn.BatchNorm2d(64),
)
self.layers_p4 = nn.Sequential(
conv_func(320, 64, **conv_kwargs, **qspace), # concat: reorg(p2) 64*4 + p3 64 = 320 channels
nn.BatchNorm2d(64),
conv_func(64, 36, kernel_size=1, stride=1, padding=0, **qspace)
)
self.yololayer = YOLOLayer([[20, 20], [20, 20], [20, 20], [20, 20], [20, 20], [20, 20]])
self.yolo_layers = [self.yololayer]
def forward(self, x):
img_size = x.shape[-2:]
yolo_out, out = [], []
x_p1 = self.layers_p1(x)
x_p2 = self.layers_p2(x_p1)
x_p2_reorg = self.reorg(x_p2)
x_p3 = self.layers_p3(x_p2)
x_p4_in = torch.cat([x_p2_reorg, x_p3], 1)
x_p4 = self.layers_p4(x_p4_in)
x = self.yololayer(x_p4, img_size)
yolo_out.append(x)
if self.training: # train
return yolo_out
else: # test
io, p = zip(*yolo_out) # inference output, training output
return torch.cat(io, 1), p
return x
def fetch_best_arch(self):
sum_bitops, sum_bita, sum_bitw, sum_dsps = 0, 0, 0, 0
sum_mixbitops, sum_mixbita, sum_mixbitw, sum_mixdsps = 0, 0, 0, 0
layer_idx = 0
best_arch = None
for m in self.modules():
if isinstance(m, self.conv_func):
layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = m.fetch_best_arch(layer_idx)
if best_arch is None:
best_arch = layer_arch
else:
for key in layer_arch.keys():
if key not in best_arch:
best_arch[key] = layer_arch[key]
else:
best_arch[key].append(layer_arch[key][0])
sum_bitops += bitops
sum_bita += bita
sum_bitw += bitw
sum_mixbitops += mixbitops
sum_mixbita += mixbita
sum_mixbitw += mixbitw
sum_dsps += dsps
sum_mixdsps += mixdsps
layer_idx += 1
return best_arch, sum_bitops, sum_bita, sum_bitw, sum_mixbitops, sum_mixbita, sum_mixbitw, sum_dsps, sum_mixdsps
def complexity_loss(self):
size_product = []
loss = 0
for m in self.modules():
if isinstance(m, self.conv_func):
loss += m.complexity_loss()
size_product += [m.size_product]
normalizer = size_product[0].item()
loss /= normalizer
return loss
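complexity_loss is the hook the differentiable bitwidth search penalizes. A hypothetical training step of the shape suggested by the --complexity-decay flag added below (the actual loop lives in the training script, outside this hunk):

    def mixq_step(model, task_loss, optimizer, cd):
        # cd plays the role of --complexity-decay: it weights the expected
        # hardware cost against the detection loss.
        loss = task_loss + cd * model.complexity_loss()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()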
class UltraNetBypass_FixQ(nn.Module):
def __init__(self, bitw = '444444444', bita = '444444444'):
super(UltraNetBypass_FixQ, self).__init__()
self.reorg = ReorgLayer(stride=2)
self.conv_func = qm.QuantActivConv2d
conv_func = self.conv_func
assert(len(bitw)==0 or len(bitw)==9)
assert(len(bita)==0 or len(bita)==9)
if isinstance(bitw, str):
bitw=list(map(int, bitw))
if isinstance(bita, str):
bita=list(map(int, bita))
self.bitw = bitw
self.bita = bita
self.model_params = {'bitw': bitw, 'bita': bita}
conv_kwargs = {'kernel_size':3, 'stride':1, 'padding':1, 'bias':False}
self.layers_p1 = nn.Sequential(
conv_func(3, 16, ActQ = qm.ImageInputQ, **conv_kwargs, wbit=bitw[0], abit=bita[0]),
nn.BatchNorm2d(16),
nn.MaxPool2d(2, stride=2),
conv_func(16, 32, **conv_kwargs, wbit=bitw[1], abit=bita[1]),
nn.BatchNorm2d(32),
nn.MaxPool2d(2, stride=2),
conv_func(32, 64, **conv_kwargs, wbit=bitw[2], abit=bita[2]),
nn.BatchNorm2d(64),
nn.MaxPool2d(2, stride=2)
)
self.layers_p2 = nn.Sequential(
conv_func(64, 64, **conv_kwargs, wbit=bitw[3], abit=bita[3]),
nn.BatchNorm2d(64),
)
self.layers_p3 = nn.Sequential(
nn.MaxPool2d(2, stride=2),
conv_func(64, 64, **conv_kwargs, wbit=bitw[4], abit=bita[4]),
nn.BatchNorm2d(64),
conv_func(64, 64, **conv_kwargs, wbit=bitw[5], abit=bita[5]),
nn.BatchNorm2d(64),
conv_func(64, 64, **conv_kwargs, wbit=bitw[6], abit=bita[6]),
nn.BatchNorm2d(64),
)
self.layers_p4 = nn.Sequential(
conv_func(320, 64, **conv_kwargs, wbit=bitw[7], abit=bita[7]), # concat: reorg(p2) 64*4 + p3 64 = 320 channels
nn.BatchNorm2d(64),
conv_func(64, 36, kernel_size=1, stride=1, padding=0, wbit=bitw[8], abit=bita[8])
)
self.yololayer = YOLOLayer([[20, 20], [20, 20], [20, 20], [20, 20], [20, 20], [20, 20]])
self.yolo_layers = [self.yololayer]
def forward(self, x):
img_size = x.shape[-2:]
yolo_out, out = [], []
x_p1 = self.layers_p1(x)
x_p2 = self.layers_p2(x_p1)
x_p2_reorg = self.reorg(x_p2)
x_p3 = self.layers_p3(x_p2)
x_p4_in = torch.cat([x_p2_reorg, x_p3], 1)
x_p4 = self.layers_p4(x_p4_in)
x = self.yololayer(x_p4, img_size)
yolo_out.append(x)
if self.training: # train
return yolo_out
else: # test
io, p = zip(*yolo_out) # inference output, training output
return torch.cat(io, 1), p
return x
def fetch_arch_info(self):
sum_bitops, sum_bita, sum_bitw, sum_dsps = 0, 0, 0, 0
layer_idx = 0
for m in self.modules():
if isinstance(m, self.conv_func):
size_product = m.size_product.item()
memory_size = m.memory_size.item()
bitops = size_product * m.abit * m.wbit
bita = m.memory_size.item() * m.abit
bitw = m.param_size * m.wbit
dsps = size_product / qm.dsp_factors[m.wbit-2][m.abit-2]
weight_shape = list(m.conv.weight.shape)
print('idx {} with shape {}, bitops: {:.3f}M * {} * {}, memory: {:.3f}K * {}, '
'param: {:.3f}M * {}, dsps: {:.3f}M'.format(layer_idx, weight_shape, size_product, m.abit,
m.wbit, memory_size, m.abit, m.param_size, m.wbit, dsps))
sum_bitops += bitops
sum_bita += bita
sum_bitw += bitw
sum_dsps += dsps
layer_idx += 1
return sum_bitops, sum_bita, sum_bitw, sum_dsps
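A hypothetical hand-off from search to fixed-point training: the nine digits of bitw/bita map one-to-one onto the nine quantized convs above (three in layers_p1, one in layers_p2, three in layers_p3, two in layers_p4), and fetch_arch_info reports the resulting cost totals:

    model = UltraNetBypass_FixQ(bitw='444444444', bita='444444444')
    # Prints per-layer bitops/memory/param/DSP stats, returns the sums.
    bitops, bita, bitw, dsps = model.fetch_arch_info()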

View File

@@ -58,10 +58,13 @@ def train():
test_path = localconfig.test_path
nc = 1
results_file = 'result_%s.txt'%opt.name
results_file = 'results/%s.txt'%opt.name
# Initialize model
model = UltraNet_MixQ(opt.share_weight).to(device)
if opt.bypass:
model = UltraNetBypass_MixQ(not opt.no_share).to(device)
else:
model = UltraNet_MixQ(not opt.no_share).to(device)
# Optimizer
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
@@ -302,10 +305,10 @@ def train():
'extra': {'time': time.ctime(), 'name': opt.name}}
# Save last checkpoint
torch.save(chkpt, wdir + 'last_%s.pt'%opt.name)
torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
if test_iou == test_best_iou:
torch.save(chkpt, wdir + 'test_best_%s.pt'%opt.name)
torch.save(chkpt, wdir + '%s_best.pt'%opt.name)
# Delete checkpoint
del chkpt
@@ -324,6 +327,7 @@ def train():
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--bypass', action='store_true', help='use bypass model')
parser.add_argument('--epochs', type=int, default=35) # 500200 batches at bs 16, 117263 COCO images = 273 epochs
parser.add_argument('--batch-size', type=int, default=64) # effective bs = batch_size * accumulate = 16 * 4 = 64
parser.add_argument('--accumulate', type=int, default=1, help='batches to accumulate before optimizing')
@@ -344,7 +348,7 @@ if __name__ == '__main__':
parser.add_argument('--var', type=float, help='debug variable')
parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
parser.add_argument('--lra', '--learning-rate-alpha', default=0.01, type=float, metavar='LR', help='initial alpha learning rate')
parser.add_argument('--share-weight', action='store_true', help='share weight quantization')
parser.add_argument('--no-share', action='store_true', help='no share weight quantization')
opt = parser.parse_args()
last = wdir + 'last_%s.pt'%opt.name
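One behavioral note on the flag rename above: weight sharing across candidate bitwidths is now the default and --no-share opts out, where the old --share-weight had to opt in. The selection logic, spelled out as a sketch:

    share = not opt.no_share              # sharing is now the default
    model_cls = UltraNetBypass_MixQ if opt.bypass else UltraNet_MixQ
    model = model_cls(share).to(device)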

View File

@@ -24,7 +24,7 @@ class QConvLayer:
x = F.conv2d(x, self.w, bias=None, stride=self.conv.s, padding=self.conv.p) # [N, OCH, OROW, OCOL]
# print('convo', self.conv.n, x[0,0,:,0])
och = x.shape[1]
if False:
if True:
if self.conv.inc is not None:
inc_ch = self.conv.inc.reshape((1, och, 1, 1))
x *= inc_ch
@@ -54,6 +54,20 @@ class QConvLayer:
return x
def reorg(x):
stride = 2
B = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
ws = stride
hs = stride
x = x.view([B, C, H//hs, hs, W//ws, ws]).transpose(3, 4).contiguous()
x = x.view([B, C, H//hs*W//ws, hs*ws]).transpose(2, 3).contiguous()
x = x.view([B, C, hs*ws, H//hs, W//ws]).transpose(1, 2).contiguous()
x = x.view([B, hs*ws*C, H//hs, W//ws])
return x
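This functional reorg mirrors ReorgLayer from the model file but runs on the integer tensors of the hardware simulation; both are pure reshapes, so they should agree exactly. A sanity sketch, assuming ReorgLayer is importable here:

    import torch

    x = torch.arange(32, dtype=torch.int64).view(1, 2, 4, 4)
    assert torch.equal(reorg(x), ReorgLayer()(x.float()).to(torch.int64))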
class HWModel:
def __init__(self, model_param):
self.layers = [QConvLayer(conv_param) for conv_param in model_param]
@@ -67,8 +81,19 @@ class HWModel:
if self.layers[0].conv.abit<8: # ImageInputQ
x=x>>(8-self.layers[0].conv.abit)
for i, layer in enumerate(self.layers):
x = layer(x)
if not opt.bypass:
for i, layer in enumerate(self.layers):
x = layer(x)
else:
for i in [0,1,2,3]:
x = self.layers[i](x)
p4_in = torch.round(reorg(x) *
self.layers[4].conv.astep / self.layers[7].conv.astep).to(dtype=torch.int64)
for i in [4,5,6]:
x = self.layers[i](x)
x = torch.cat([p4_in, x], 1)
for i in [7,8]:
x = self.layers[i](x)
x = x.float() / self.layers[-1].conv.div
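The round(x * astep_src / astep_dst) factor above converts the bypass branch onto the scale the consuming layer expects: both concatenated tensors are integers, but they were quantized with different activation steps, and layer 7 assumes a single step. A numeric illustration with made-up steps:

    import torch

    astep_src, astep_dst = 0.125, 0.0625   # illustrative activation steps
    q_src = torch.tensor([3, -5, 7], dtype=torch.int64)
    q_dst = torch.round(q_src * astep_src / astep_dst).to(torch.int64)
    # Same real values, re-expressed on the destination grid:
    assert torch.allclose(q_dst * astep_dst, q_src * astep_src)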
@@ -110,6 +135,7 @@ def testdataset(hwmodel):
if __name__=='__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-w', '--weight', help='weight folder name in ./hls/, which contains model_param.pkl')
parser.add_argument('-bp', '--bypass', action='store_true', help='use bypass model')
parser.add_argument('--datapath', default='../../dacsdc_dataset', help = 'test dataset path')
parser.add_argument('-bs', '--batch-size', type=int, default=1, help = 'batch-size')
parser.add_argument('-nb', '--num-batch', type=int, default=1, help = 'num of batchs to run, -1 for full dataset')

View File

@@ -67,14 +67,14 @@ def train():
test_path = localconfig.test_path
nc = 1
results_file = 'result_%s.txt'%opt.name
results_file = 'results/%s.txt'%opt.name
# Remove previous results
for f in glob.glob('*_batch*.png') + glob.glob(results_file):
os.remove(f)
# Initialize model
# model = Darknet(cfg, arc=opt.arc).to(device)
model = UltraNetFloat().to(device)
model = UltraNetBypassFloat().to(device)
# model = SkyNet().to(device)
# model = TempNet().to(device)
# model = TempNetDW().to(device)
@@ -346,14 +346,10 @@ def train():
'optimizer': None if final_epoch else optimizer.state_dict()}
# Save last checkpoint
torch.save(chkpt, wdir + 'last_%s.pt'%opt.name)
# Save best checkpoint
if best_fitness == fi:
torch.save(chkpt, wdir + 'best_%s.pt'%opt.name)
torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
if test_iou == test_best_iou:
torch.save(chkpt, wdir + 'test_best_%s.pt'%opt.name)
torch.save(chkpt, wdir + '%s_best.pt'%opt.name)
# Save backup every 10 epochs (optional)
# if epoch > 0 and epoch % 10 == 0: