add trivial_loss exp; add pareto_train exp; fix vgg_tiny pooling order; separate packing_factor

fffasttime 2023-04-03 21:54:43 +08:00
parent 0b983a1223
commit 587637c4fc
9 changed files with 174 additions and 69 deletions

anypacking/dsp_packing.py (new file, +29 lines)

@ -0,0 +1,29 @@
factors_k11=[
[12,8,8,6,6,4,4],
[10,8,6,6,4,4,4],
[8,6,6,4,4,4,3],
[6,6,4,4,4,4,2],
[6,4,4,4,2,2,2],
[4,4,4,4,2,2,2],
[4,4,3,2,2,2,2],
]
factors_k33=[
[18,15,12,7.5,7.5,6,6],
[15,12,7.5,6,6,6,3],
[12,7.5,6,6,6,6,3],
[9,6,6,6,6,3,3],
[7.5,6,6,4.5,3,3,3],
[6,6,4.5,3,3,3,2.25],
[6,3,3,3,3,3,2],
]
factors_k55=[
[20,15,10,7.5,7.5,5,5],
[12.5,10,6.67,5,5,5,3.33],
[10,7.5,5,5,5,5,3.33],
[7.5,6.67,5,5,5,3.33,3.33],
[6.67,5,5,5,3.33,2.5,2.5],
[5,5,5,3.33,2.5,2.5,2.5],
[5,3.33,3.33,3.33,2.5,2.5,2],
]
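
The three tables list how many low-precision multiply-accumulates fit into one DSP block for 1x1, 3x3 and 5x5 kernels; rows index the weight bit-width and columns the activation bit-width, both from 2 to 8 bits, which is why the lookups later in this commit subtract 2. A minimal sketch of the lookup (the helper name and the top-level import path are assumptions, not part of the commit):

from anypacking.dsp_packing import factors_k11, factors_k33, factors_k55

def packing_factor(kernel_size: int, wbit: int, abit: int) -> float:
    # rows cover weight bits 2..8, columns activation bits 2..8, hence the -2 offset
    table = {1: factors_k11, 3: factors_k33, 5: factors_k55}[kernel_size]
    return table[wbit - 2][abit - 2]

# e.g. a 3x3 conv with 4-bit weights and 4-bit activations packs 6 MACs per DSP:
# packing_factor(3, 4, 4) == 6, so each MAC costs 1/6 of a DSP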

(changed file)

@ -5,40 +5,12 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
# load dsp packing factors
from .dsp_packing import *
gaussian_steps = {1: 1.596, 2: 0.996, 3: 0.586, 4: 0.336, 5: 0.190, 6: 0.106, 7: 0.059, 8: 0.032}
hwgq_steps = {1: 0.799, 2: 0.538, 3: 0.3217, 4: 0.185, 5: 0.104, 6: 0.058, 7: 0.033, 8: 0.019}
dsp_factors_k11=[
[12,8,8,6,6,4,4],
[10,8,6,6,4,4,4],
[8,6,6,4,4,4,3],
[6,6,4,4,4,4,2],
[6,4,4,4,2,2,2],
[4,4,4,4,2,2,2],
[4,4,3,2,2,2,2],
]
dsp_factors_k33=[
[18,15,12,7.5,7.5,6,6],
[15,12,7.5,6,6,6,3],
[12,7.5,6,6,6,6,3],
[9,6,6,6,6,3,3],
[7.5,6,6,4.5,3,3,3],
[6,6,4.5,3,3,3,2.25],
[6,3,3,3,3,3,2],
]
dsp_factors_k55=[
[20,15,10,7.5,7.5,5,5],
[12.5,10,6.67,5,5,5,3.33],
[10,7.5,5,5,5,5,3.33],
[7.5,6.67,5,5,5,3.33,3.33],
[6.67,5,5,5,3.33,2.5,2.5],
[5,5,5,3.33,2.5,2.5,2.5],
[5,3.33,3.33,3.33,2.5,2.5,2],
]
class _gauss_quantize_sym(torch.autograd.Function):
@staticmethod
@ -394,7 +366,7 @@ class MixActivConv2d(nn.Module):
out = self.mix_weight(out)
return out
def complexity_loss_old(self):
def complexity_loss_trivial(self):
sw = F.softmax(self.mix_activ.alpha_activ, dim=0)
mix_abit = 0
abits = self.mix_activ.bits
@ -416,16 +388,16 @@ class MixActivConv2d(nn.Module):
wbits = self.mix_weight.bits
if self.kernel_size == 1:
dsp_factors = dsp_factors_k11
factors = factors_k11
elif self.kernel_size == 3:
dsp_factors = dsp_factors_k33
factors = factors_k33
elif self.kernel_size == 5:
dsp_factors = dsp_factors_k55
factors = factors_k55
else:
raise NotImplementedError
for i in range(len(wbits)):
for j in range(len(abits)):
mix_scale += sw[i] * sa[j] / dsp_factors[wbits[i]-2][abits[j]-2]
mix_scale += sw[i] * sa[j] / factors[wbits[i]-2][abits[j]-2]
complexity = self.size_product.item() * 64 * mix_scale
return complexity
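
During search the layer has no single bit pair yet, so complexity_loss takes an expectation over the softmax of the architecture parameters: each (weight-bit, activation-bit) combination contributes its probability divided by its packing factor, scaled by the layer's MAC count. A self-contained sketch with illustrative values (the real alphas and size_product belong to MixActivConv2d):

import torch
import torch.nn.functional as F
from anypacking.dsp_packing import factors_k33

wbits, abits = [2, 4, 8], [2, 4, 8]                 # candidate bit-widths
alpha_w = torch.tensor([0.1, 1.0, 0.2])             # weight architecture params
alpha_a = torch.tensor([0.3, 0.8, 0.1])             # activation architecture params
size_product = 1.5                                  # layer MACs in millions (example)

sw, sa = F.softmax(alpha_w, dim=0), F.softmax(alpha_a, dim=0)
mix_scale = sum(sw[i] * sa[j] / factors_k33[wbits[i] - 2][abits[j] - 2]
                for i in range(len(wbits)) for j in range(len(abits)))
complexity = size_product * 64 * mix_scale          # expected DSP cost term, scaled by 64 as in complexity_loss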
@ -493,21 +465,21 @@ class MixActivConv2d(nn.Module):
bitw = self.param_size * wbits[best_weight]
if self.kernel_size == 1:
dsp_factors = dsp_factors_k11
factors = factors_k11
elif self.kernel_size == 3:
dsp_factors = dsp_factors_k33
factors = factors_k33
elif self.kernel_size == 5:
dsp_factors = dsp_factors_k55
factors = factors_k55
else:
raise NotImplementedError
dsps = size_product / dsp_factors[wbits[best_weight]-2][abits[best_activ]-2]
dsps = size_product / factors[wbits[best_weight]-2][abits[best_activ]-2]
mixbitops = size_product * mix_abit * mix_wbit
mixbita = memory_size * mix_abit
mixbitw = self.param_size * mix_wbit
mixdsps = 0
for i in range(len(wbits)):
for j in range(len(abits)):
mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors[wbits[i]-2][abits[j]-2]
mixdsps += prob_weight[i] * prob_activ[j] / factors[wbits[i]-2][abits[j]-2]
mixdsps *= size_product
mixbram_weight = self.param_size * 1e3 * mix_wbit # kbit
mixbram_cache = bram_sw * mix_abit # kbit
@ -595,7 +567,7 @@ class MixActivLinear(nn.Module):
wbits = self.mix_weight.bits
for i in range(len(wbits)):
for j in range(len(abits)):
mix_scale += sw[i] * sa[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
mix_scale += sw[i] * sa[j] / factors_k11[wbits[i]-2][abits[j]-2]
complexity = self.size_product.item() * 64 * mix_scale
return complexity
@ -627,13 +599,13 @@ class MixActivLinear(nn.Module):
bitops = size_product * abits[best_activ] * wbits[best_weight]
bita = memory_size * abits[best_activ]
bitw = self.param_size * wbits[best_weight]
dsps = size_product / dsp_factors_k11[wbits[best_weight]-2][abits[best_activ]-2]
dsps = size_product / factors_k11[wbits[best_weight]-2][abits[best_activ]-2]
mixbitops = size_product * mix_abit * mix_wbit
mixbita = memory_size * mix_abit
mixbitw = self.param_size * mix_wbit
mixdsps = 0
for i in range(len(wbits)):
for j in range(len(abits)):
mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
mixdsps += prob_weight[i] * prob_activ[j] / factors_k11[wbits[i]-2][abits[j]-2]
mixdsps *= size_product
return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps
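
Once the architecture parameters have converged, fetch_best_arch picks the highest-probability bit-width per layer and turns the layer's MAC count into a DSP estimate with the same tables. A standalone sketch with illustrative numbers (the real alphas and size_product live inside the Mix modules):

import torch
from anypacking.dsp_packing import factors_k33

wbits, abits = [2, 4, 8], [2, 4, 8]
alpha_weight = torch.tensor([0.1, 1.0, 0.2])
alpha_activ = torch.tensor([0.3, 0.8, 0.1])

best_weight = int(alpha_weight.argmax())          # index of the most probable weight bit-width
best_activ = int(alpha_activ.argmax())
size_product = 1.5                                # layer MACs in millions (example)
dsps = size_product / factors_k33[wbits[best_weight] - 2][abits[best_activ] - 2]
# 4-bit x 4-bit on a 3x3 conv packs 6 MACs/DSP, so dsps == 0.25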

(changed file)

@ -105,22 +105,36 @@ def train():
print('Finished Training')
with open('results.csv', 'a') as f:
print("fixed,%s,%d/%d, , ,%s,%s,%.1f,%.1f, , , ,%d, ,%.3f, "%
(opt.name,epochs-1,epochs,opt.bitw,opt.bita,macc*100,(test_acc+test_best_acc)/2,
int(round(bops)), dsps), file=f)
# torch.save(net.state_dict(), 'lenet_cifar10.pth')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-n', '--name', default='VGG_tiny_FixQ', help='result and weight file name')
parser.add_argument('-w', '--weights', default=None, help='weights path')
parser.add_argument('-e', '--epochs', type=int, default=40)
parser.add_argument('-e', '--epochs', type=int, default=200)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--bypass', action='store_true', help='use bypass model')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--lr', type=float, default=0.03)
parser.add_argument('--mixm', type=str)
parser.add_argument('--bitw', type=str, default='')
parser.add_argument('--bita', type=str, default='')
opt = parser.parse_args()
if opt.mixm is not None:
wmix = torch.load('weights/%s.pt'%opt.mixm)
opt.bitw = wmix['extra']['bestw']
opt.bita = wmix['extra']['besta']
del wmix
print(opt)
wdir = 'weights' + os.sep # weights dir
last = wdir + '%s_last.pt'%opt.name
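
The new --mixm option bridges the two training stages: it loads the checkpoint written by the mixed-precision search and copies its extra['bestw'] / extra['besta'] digit strings into opt.bitw / opt.bita, so the fixed-precision model is built with the searched per-layer bit-widths (one digit per quantized layer, cf. the '444444' / '844444' defaults of VGG_tiny_FixQ). A hedged sketch of that hand-off (the checkpoint name is illustrative):

import torch

wmix = torch.load('weights/VGG_tiny_MixQ_last.pt')            # checkpoint from the search run
bitw, bita = wmix['extra']['bestw'], wmix['extra']['besta']   # e.g. '444444', '844444'
per_layer_wbits = [int(c) for c in bitw]                      # weight bits, one per conv layer
per_layer_abits = [int(c) for c in bita]                      # activation bits, one per conv layer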

(changed file)

@ -125,22 +125,22 @@ class VGG_tiny_MixQ(nn.Module):
nn.BatchNorm2d(64),
conv_func(64, 64, **conv_kwargs, **qspace), # 1
self.pooling,
nn.BatchNorm2d(64),
self.pooling,
conv_func(64, 128, **conv_kwargs, **qspace), # 2
nn.BatchNorm2d(128),
conv_func(128, 128, **conv_kwargs, **qspace), # 3
self.pooling,
nn.BatchNorm2d(128),
self.pooling,
conv_func(128, 256, **conv_kwargs, **qspace), # 4
nn.BatchNorm2d(256),
conv_func(256, 256, **conv_kwargs, **qspace), # 5
self.pooling,
nn.BatchNorm2d(256),
self.pooling,
nn.Flatten(),
qm.QuantActivLinear(256*4*4, num_classes, bias=True, wbit=8, abit=8)
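
The "fix vgg_tiny pooling order" part of the commit swaps each pooling/BatchNorm pair so that every block now runs conv -> BatchNorm -> MaxPool instead of conv -> MaxPool -> BatchNorm; the same reordering is applied to VGG_tiny_FixQ below. One corrected block, sketched in isolation with the channel count from the diff:

import torch.nn as nn

block = nn.Sequential(
    # conv_func(64, 64, **conv_kwargs, **qspace),   # quantized 3x3 conv, as in the model
    nn.BatchNorm2d(64),                              # normalize first ...
    nn.MaxPool2d(kernel_size=2, stride=2),           # ... then downsample
)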
@ -156,7 +156,7 @@ class VGG_tiny_MixQ(nn.Module):
best_arch = None
for m in self.modules():
if isinstance(m, self.conv_func):
layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = m.fetch_best_arch(layer_idx)
layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = m.fetch_best_arch(layer_idx)
if best_arch is None:
best_arch = layer_arch
else:
@ -186,6 +186,17 @@ class VGG_tiny_MixQ(nn.Module):
normalizer = size_product[0].item()
loss /= normalizer
return loss
def complexity_loss_trivial(self):
size_product = []
loss = 0
for m in self.modules():
if isinstance(m, self.conv_func):
loss += m.complexity_loss_trivial()
size_product += [m.size_product]
normalizer = size_product[0].item()
loss /= normalizer
return loss
class VGG_tiny_FixQ(nn.Module):
def __init__(self, num_classes=10, bitw = '444444', bita = '844444'):
@ -211,22 +222,22 @@ class VGG_tiny_FixQ(nn.Module):
nn.BatchNorm2d(64),
conv_func(64, 64, **conv_kwargs, wbit=bitw[1], abit=bita[1]), # 1
nn.MaxPool2d(kernel_size=2, stride=2),
nn.BatchNorm2d(64),
nn.MaxPool2d(kernel_size=2, stride=2),
conv_func(64, 128, **conv_kwargs, wbit=bitw[2], abit=bita[2]), # 2
nn.BatchNorm2d(128),
conv_func(128, 128, **conv_kwargs, wbit=bitw[3], abit=bita[3]), # 3
nn.MaxPool2d(kernel_size=2, stride=2),
nn.BatchNorm2d(128),
nn.MaxPool2d(kernel_size=2, stride=2),
conv_func(128, 256, **conv_kwargs, wbit=bitw[4], abit=bita[4]), # 4
nn.BatchNorm2d(256),
conv_func(256, 256, **conv_kwargs, wbit=bitw[5], abit=bita[5]), # 5
nn.MaxPool2d(kernel_size=2, stride=2),
nn.BatchNorm2d(256),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
qm.QuantActivLinear(256*4*4, num_classes, bias=True, wbit=8, abit=8)
@ -245,7 +256,7 @@ class VGG_tiny_FixQ(nn.Module):
bitops = size_product * m.abit * m.wbit
bita = m.memory_size.item() * m.abit
bitw = m.param_size * m.wbit
dsps = size_product / qm.dsp_factors[m.wbit-2][m.abit-2]
dsps = size_product / qm.dsp_factors_k33[m.wbit-2][m.abit-2]
weight_shape = list(m.conv.weight.shape)
print('idx {} with shape {}, bitops: {:.3f}M * {} * {}, memory: {:.3f}K * {}, '
'param: {:.3f}M * {}, dsps: {:.3f}M'.format(layer_idx, weight_shape, size_product, m.abit,

(changed file)

@ -75,8 +75,9 @@ def train():
correct = (predicted == labels).sum().item()
loss = criterion(outputs, labels)
if opt.complexity_decay != 0:
loss_complexity = opt.complexity_decay * model.complexity_loss()
if opt.complexity_decay != 0 or opt.complexity_decay_trivial!=0:
loss_complexity = opt.complexity_decay * model.complexity_loss() + \
opt.complexity_decay_trivial * model.complexity_loss_trivial()
loss += loss_complexity
loss.backward()
@ -94,6 +95,8 @@ def train():
bitops, bita, bitw, dsps))
print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(
mixbitops, mixbita, mixbitw, mixdsps))
bestw_str = "".join([str(x+2) for x in best_arch["best_weight"]])
besta_str = "".join([str(x+2) for x in best_arch["best_activ"]])
print(f'best_weight: {best_arch["best_weight"]}')
print(f'best_activ: {best_arch["best_activ"]}')
@ -115,7 +118,7 @@ def train():
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': None if final_epoch else optimizer.state_dict(),
'arch_optimizer': None if final_epoch else arch_optimizer.state_dict(),
'extra': {'time': time.ctime(), 'name': opt.name}}
'extra': {'time': time.ctime(), 'name': opt.name, 'bestw': bestw_str, 'besta': besta_str}}
# Save last checkpoint
torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
@ -124,6 +127,12 @@ def train():
print('Finished Training')
with open('results.csv', 'a') as f:
print("mixed,%s,%d/%d, , , , ,%.1f,%.1f, ,%s,%s,%d,%d,%.3f,%.3f"%
(opt.name,epochs-1,epochs,macc*100,(test_acc+test_best_acc)/2,
bestw_str,besta_str,
int(round(bitops)), int(round(mixbitops)), dsps, mixdsps), file=f)
# torch.save(net.state_dict(), 'lenet_cifar10.pth')
if __name__ == '__main__':
@ -135,6 +144,7 @@ if __name__ == '__main__':
parser.add_argument('--name', default='', help='result and weight file name')
parser.add_argument('--noshare', action='store_true', help='no share weight')
parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
parser.add_argument('--complexity-decay-trivial', '--cdt', default=0, type=float, metavar='W', help='complexity decay w/o hardware-aware packing (default: 0)')
parser.add_argument('--lra', '--learning-rate-alpha', default=0.1, type=float, metavar='LR', help='initial alpha learning rate')
opt = parser.parse_args()
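
With the new flag the search can mix two penalties: --cd weights the hardware-aware complexity_loss (expected DSPs via the packing tables) and --cdt weights the packing-agnostic complexity_loss_trivial; both terms are added onto the task loss. A condensed sketch of the combined objective after this commit (opt and model as in this script):

loss = criterion(outputs, labels)
if opt.complexity_decay != 0 or opt.complexity_decay_trivial != 0:
    loss = loss \
        + opt.complexity_decay * model.complexity_loss() \
        + opt.complexity_decay_trivial * model.complexity_loss_trivial()
loss.backward()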

(changed file)

@ -244,6 +244,8 @@ def train():
# Update scheduler
scheduler.step()
train_iou = mloss[2]
# Process epoch results
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
@ -293,6 +295,11 @@ def train():
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
torch.cuda.empty_cache()
with open('results.csv', 'a') as f:
print("fixed,%s,%d/%d, , ,%s,%s,%.1f,%.1f, , , ,%d, ,%.3f, "%
(opt.name,epochs-1,epochs,opt.bitw,opt.bita,train_iou*100,(test_iou+test_best_iou)*50,
int(round(bops)), dsps), file=f)
return results
@ -317,12 +324,19 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--mixm', type=str)
parser.add_argument('--bitw', type=str, default='')
parser.add_argument('--bita', type=str, default='')
parser.add_argument('--var', type=float, help='debug variable')
parser.add_argument('--model', type=str, default='', help='use specific model')
opt = parser.parse_args()
if opt.mixm is not None:
wmix = torch.load('weights/%s.pt'%opt.mixm)
opt.bitw = wmix['extra']['bestw']
opt.bita = wmix['extra']['besta']
del wmix
last = wdir + 'last_%s.pt'%opt.name
opt.weights = last if opt.resume else opt.weights
print(opt)

(changed file)

@ -152,6 +152,17 @@ def mixq_bram_loss(self):
loss /= normalizer
return loss
def mixq_complexity_loss_trivial(self):
size_product = []
loss = 0
for m in self.modules():
if isinstance(m, self.conv_func):
loss += m.complexity_loss_trivial()
size_product += [m.size_product]
normalizer = size_product[0].item()
loss /= normalizer
return loss
class UltraNet_ismart(nn.Module):
def __init__(self):
super(UltraNet_ismart, self).__init__()
@ -367,7 +378,8 @@ class UltraNet_MixQ(nn.Module):
fetch_best_arch = mixq_fetch_best_arch
# def complexity_loss(self):
complexity_loss= mixq_complexity_loss
complexity_loss = mixq_complexity_loss
complexity_loss_trivial = mixq_complexity_loss_trivial
bram_loss = mixq_bram_loss
class UltraNet_FixQ(nn.Module):
@ -685,7 +697,8 @@ class UltraNetBypass_MixQ(nn.Module):
fetch_best_arch = mixq_fetch_best_arch
# def complexity_loss(self):
complexity_loss= mixq_complexity_loss
complexity_loss = mixq_complexity_loss
complexity_loss_trivial = mixq_complexity_loss_trivial
bram_loss = mixq_bram_loss
class UltraNetBypass_FixQ(nn.Module):
@ -876,7 +889,8 @@ class SkyNet_MixQ(nn.Module):
fetch_best_arch = mixq_fetch_best_arch
# def complexity_loss(self):
complexity_loss= mixq_complexity_loss
complexity_loss = mixq_complexity_loss
complexity_loss_trivial = mixq_complexity_loss_trivial
bram_loss = mixq_bram_loss
class SkyNetk5_MixQ(nn.Module):

dacsdc/pareto_train.py (new file, +30 lines)

@ -0,0 +1,30 @@
import os
import argparse
cds = {
'cd':['3e-5', '6e-5', '1e-4', '2e-4', '3e-4'],
'cdt':['1e-5', '2e-5', '3e-5', '6e-5', '1e-4'],
}
def search_train():
for cd in cds[opt.arg]:
name = '%d_%s_'%(opt.it, opt.arg)+cd.replace('-','').replace('.','')
os.system('python search_train.py --name %s --cd %s'%('f'+name, cd))
def main_train():
for cd in cds[opt.arg]:
name = '%d_%s_'%(opt.it, opt.arg)+cd.replace('-','').replace('.','')
os.system('python main_train.py --name %s --mixm %s'%('x'+name, 'f'+name+'_last'))
parser = argparse.ArgumentParser()
parser.add_argument('--search', action='store_true')
parser.add_argument('--main', action='store_true')
parser.add_argument('--it', type=int)
parser.add_argument('--arg', type=str)
opt = parser.parse_args()
if opt.search:
search_train()
if opt.main:
main_train()
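
pareto_train.py automates a small Pareto sweep over the decay strength: with --search it launches search_train.py once per value of the chosen decay ('cd' for the hardware-aware penalty, 'cdt' for the trivial one), naming each run f<it>_<arg>_<value>; with --main it then launches main_train.py on every resulting *_last checkpoint via --mixm, prefixing the fixed-precision runs with 'x'. A hedged usage example (run from the dacsdc directory):

import os

# iteration 1, sweep the hardware-aware decay; produces runs f1_cd_3e5 ... f1_cd_3e4
os.system('python pareto_train.py --search --it 1 --arg cd')
# then retrain each searched configuration at fixed precision as x1_cd_3e5 ... x1_cd_3e4
os.system('python pareto_train.py --main --it 1 --arg cd')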

(changed file)

@ -239,16 +239,15 @@ def train():
# complexity penalty
if opt.complexity_decay != 0:
if hasattr(model, 'module'):
loss_complexity = opt.complexity_decay * model.module.complexity_loss()
else:
loss_complexity = opt.complexity_decay * model.complexity_loss()
loss_complexity = opt.complexity_decay * model.complexity_loss()
loss += loss_complexity * 4.0
if opt.complexity_decay_trivial != 0:
loss_complexity_trivial = opt.complexity_decay_trivial * model.complexity_loss_trivial()
loss += loss_complexity_trivial * 4.0
if opt.bram_decay != 0:
if hasattr(model, 'module'):
loss_bram = opt.bram_decay * model.module.bram_loss()
else:
loss_bram = opt.bram_decay * model.bram_loss()
loss += loss_bram * 4.0
@ -278,13 +277,18 @@ def train():
bitops, bita, bitw, dsps))
print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M, bram_wa:({:.3f},{:.3f})K'.format(
mixbitops, mixbita, mixbitw, mixdsps, mixbram_weight, mixbram_cache))
print(f'best_weight: {best_arch["best_weight"]} ({"".join([str(x+2) for x in best_arch["best_weight"]])})')
print(f'best_activ: {best_arch["best_activ"]} ({"".join([str(x+2) for x in best_arch["best_activ"]])})')
bestw_str = "".join([str(x+2) for x in best_arch["best_weight"]])
besta_str = "".join([str(x+2) for x in best_arch["best_activ"]])
print(f'best_weight: {best_arch["best_weight"]}')
print(f'best_activ: {best_arch["best_activ"]}')
# Update scheduler
scheduler.step()
arch_scheduler.step()
train_iou = mloss[2]
# Process epoch results
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
@ -314,7 +318,7 @@ def train():
'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': None if final_epoch else optimizer.state_dict(),
'extra': {'time': time.ctime(), 'name': opt.name}}
'extra': {'time': time.ctime(), 'name': opt.name, 'bestw': bestw_str, 'besta': besta_str}}
# Save last checkpoint
torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
@ -333,6 +337,12 @@ def train():
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
torch.cuda.empty_cache()
with open('results.csv', 'a') as f:
print("mixed,%s,%d/%d, , , , ,%.1f,%.1f, ,%s,%s,%d,%d,%.3f,%.3f"%
(opt.name,epochs-1,epochs,train_iou*100,(test_iou+test_best_iou)*50,
bestw_str,besta_str,
int(round(bitops)), int(round(mixbitops)), dsps, mixdsps), file=f)
return results
@ -359,6 +369,7 @@ if __name__ == '__main__':
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--var', type=float, help='debug variable')
parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
parser.add_argument('--complexity-decay-trivial', '--cdt', default=0, type=float, metavar='W', help='complexity decay w/o hardware-aware packing (default: 0)')
parser.add_argument('--bram-decay', '--bd', default=0, type=float, metavar='W', help='bram decay (default: 0)')
parser.add_argument('--lra', '--learning-rate-alpha', default=0.01, type=float, metavar='LR', help='initial alpha learning rate')
parser.add_argument('--no-share', action='store_true', help='no share weight quantization')