forked from liucheng/DeepBurning-MixQ
add trivial_loss exp; add pareto_train exp; fix vgg_tiny pooling order; seperate packing_factor
This commit is contained in:
parent
0b983a1223
commit
587637c4fc
|
@ -0,0 +1,29 @@
|
|||
factors_k11=[
|
||||
[12,8,8,6,6,4,4],
|
||||
[10,8,6,6,4,4,4],
|
||||
[8,6,6,4,4,4,3],
|
||||
[6,6,4,4,4,4,2],
|
||||
[6,4,4,4,2,2,2],
|
||||
[4,4,4,4,2,2,2],
|
||||
[4,4,3,2,2,2,2],
|
||||
]
|
||||
|
||||
factors_k33=[
|
||||
[18,15,12,7.5,7.5,6,6],
|
||||
[15,12,7.5,6,6,6,3],
|
||||
[12,7.5,6,6,6,6,3],
|
||||
[9,6,6,6,6,3,3],
|
||||
[7.5,6,6,4.5,3,3,3],
|
||||
[6,6,4.5,3,3,3,2.25],
|
||||
[6,3,3,3,3,3,2],
|
||||
]
|
||||
|
||||
factors_k55=[
|
||||
[20,15,10,7.5,7.5,5,5],
|
||||
[12.5,10,6.67,5,5,5,3.33],
|
||||
[10,7.5,5,5,5,5,3.33],
|
||||
[7.5,6.67,5,5,5,3.33,3.33],
|
||||
[6.67,5,5,5,3.33,2.5,2.5],
|
||||
[5,5,5,3.33,2.5,2.5,2.5],
|
||||
[5,3.33,3.33,3.33,2.5,2.5,2],
|
||||
]
|
|
@ -5,40 +5,12 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
# load dsp packing factors
|
||||
from .dsp_packing import *
|
||||
|
||||
gaussian_steps = {1: 1.596, 2: 0.996, 3: 0.586, 4: 0.336, 5: 0.190, 6: 0.106, 7: 0.059, 8: 0.032}
|
||||
hwgq_steps = {1: 0.799, 2: 0.538, 3: 0.3217, 4: 0.185, 5: 0.104, 6: 0.058, 7: 0.033, 8: 0.019}
|
||||
|
||||
|
||||
dsp_factors_k11=[
|
||||
[12,8,8,6,6,4,4],
|
||||
[10,8,6,6,4,4,4],
|
||||
[8,6,6,4,4,4,3],
|
||||
[6,6,4,4,4,4,2],
|
||||
[6,4,4,4,2,2,2],
|
||||
[4,4,4,4,2,2,2],
|
||||
[4,4,3,2,2,2,2],
|
||||
]
|
||||
|
||||
dsp_factors_k33=[
|
||||
[18,15,12,7.5,7.5,6,6],
|
||||
[15,12,7.5,6,6,6,3],
|
||||
[12,7.5,6,6,6,6,3],
|
||||
[9,6,6,6,6,3,3],
|
||||
[7.5,6,6,4.5,3,3,3],
|
||||
[6,6,4.5,3,3,3,2.25],
|
||||
[6,3,3,3,3,3,2],
|
||||
]
|
||||
|
||||
dsp_factors_k55=[
|
||||
[20,15,10,7.5,7.5,5,5],
|
||||
[12.5,10,6.67,5,5,5,3.33],
|
||||
[10,7.5,5,5,5,5,3.33],
|
||||
[7.5,6.67,5,5,5,3.33,3.33],
|
||||
[6.67,5,5,5,3.33,2.5,2.5],
|
||||
[5,5,5,3.33,2.5,2.5,2.5],
|
||||
[5,3.33,3.33,3.33,2.5,2.5,2],
|
||||
]
|
||||
|
||||
class _gauss_quantize_sym(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
|
@ -394,7 +366,7 @@ class MixActivConv2d(nn.Module):
|
|||
out = self.mix_weight(out)
|
||||
return out
|
||||
|
||||
def complexity_loss_old(self):
|
||||
def complexity_loss_trivial(self):
|
||||
sw = F.softmax(self.mix_activ.alpha_activ, dim=0)
|
||||
mix_abit = 0
|
||||
abits = self.mix_activ.bits
|
||||
|
@ -416,16 +388,16 @@ class MixActivConv2d(nn.Module):
|
|||
wbits = self.mix_weight.bits
|
||||
|
||||
if self.kernel_size == 1:
|
||||
dsp_factors = dsp_factors_k11
|
||||
factors = factors_k11
|
||||
elif self.kernel_size == 3:
|
||||
dsp_factors = dsp_factors_k33
|
||||
factors = factors_k33
|
||||
elif self.kernel_size == 5:
|
||||
dsp_factors = dsp_factors_k55
|
||||
factors = factors_k55
|
||||
else:
|
||||
raise NotImplementedError
|
||||
for i in range(len(wbits)):
|
||||
for j in range(len(abits)):
|
||||
mix_scale += sw[i] * sa[j] / dsp_factors[wbits[i]-2][abits[j]-2]
|
||||
mix_scale += sw[i] * sa[j] / factors[wbits[i]-2][abits[j]-2]
|
||||
complexity = self.size_product.item() * 64 * mix_scale
|
||||
return complexity
|
||||
|
||||
|
@ -493,21 +465,21 @@ class MixActivConv2d(nn.Module):
|
|||
bitw = self.param_size * wbits[best_weight]
|
||||
|
||||
if self.kernel_size == 1:
|
||||
dsp_factors = dsp_factors_k11
|
||||
factors = factors_k11
|
||||
elif self.kernel_size == 3:
|
||||
dsp_factors = dsp_factors_k33
|
||||
factors = factors_k33
|
||||
elif self.kernel_size == 5:
|
||||
dsp_factors = dsp_factors_k55
|
||||
factors = factors_k55
|
||||
else:
|
||||
raise NotImplementedError
|
||||
dsps = size_product / dsp_factors[wbits[best_weight]-2][abits[best_activ]-2]
|
||||
dsps = size_product / factors[wbits[best_weight]-2][abits[best_activ]-2]
|
||||
mixbitops = size_product * mix_abit * mix_wbit
|
||||
mixbita = memory_size * mix_abit
|
||||
mixbitw = self.param_size * mix_wbit
|
||||
mixdsps = 0
|
||||
for i in range(len(wbits)):
|
||||
for j in range(len(abits)):
|
||||
mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors[wbits[i]-2][abits[j]-2]
|
||||
mixdsps += prob_weight[i] * prob_activ[j] / factors[wbits[i]-2][abits[j]-2]
|
||||
mixdsps *= size_product
|
||||
mixbram_weight = self.param_size * 1e3 * mix_wbit # kbit
|
||||
mixbram_cache = bram_sw * mix_abit # kbit
|
||||
|
@ -595,7 +567,7 @@ class MixActivLinear(nn.Module):
|
|||
wbits = self.mix_weight.bits
|
||||
for i in range(len(wbits)):
|
||||
for j in range(len(abits)):
|
||||
mix_scale += sw[i] * sa[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
|
||||
mix_scale += sw[i] * sa[j] / factors_k11[wbits[i]-2][abits[j]-2]
|
||||
complexity = self.size_product.item() * 64 * mix_scale
|
||||
return complexity
|
||||
|
||||
|
@ -627,13 +599,13 @@ class MixActivLinear(nn.Module):
|
|||
bitops = size_product * abits[best_activ] * wbits[best_weight]
|
||||
bita = memory_size * abits[best_activ]
|
||||
bitw = self.param_size * wbits[best_weight]
|
||||
dsps = size_product / dsp_factors_k11[wbits[best_weight]-2][abits[best_activ]-2]
|
||||
dsps = size_product / factors_k11[wbits[best_weight]-2][abits[best_activ]-2]
|
||||
mixbitops = size_product * mix_abit * mix_wbit
|
||||
mixbita = memory_size * mix_abit
|
||||
mixbitw = self.param_size * mix_wbit
|
||||
mixdsps = 0
|
||||
for i in range(len(wbits)):
|
||||
for j in range(len(abits)):
|
||||
mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
|
||||
mixdsps += prob_weight[i] * prob_activ[j] / factors_k11[wbits[i]-2][abits[j]-2]
|
||||
mixdsps *= size_product
|
||||
return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps
|
||||
|
|
|
@ -105,22 +105,36 @@ def train():
|
|||
|
||||
print('Finished Training')
|
||||
|
||||
with open('results.csv', 'a') as f:
|
||||
print("fixed,%s,%d/%d, , ,%s,%s,%.1f,%.1f, , , ,%d, ,%.3f, "%
|
||||
(opt.name,epochs-1,epochs,opt.bitw,opt.bita,macc*100,(test_acc+test_best_acc)/2,
|
||||
int(round(bops)), dsps), file=f)
|
||||
|
||||
# torch.save(net.state_dict(), 'lenet_cifar10.pth')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-n', '--name', default='VGG_tiny_FixQ', help='result and weight file name')
|
||||
parser.add_argument('-w', '--weights', default=None, help='weights path')
|
||||
parser.add_argument('-e', '--epochs', type=int, default=40)
|
||||
parser.add_argument('-e', '--epochs', type=int, default=200)
|
||||
parser.add_argument('--batch-size', type=int, default=128)
|
||||
parser.add_argument('--bypass', action='store_true', help='use bypass model')
|
||||
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
|
||||
parser.add_argument('--lr', type=float, default=0.03)
|
||||
parser.add_argument('--mixm', type=str)
|
||||
parser.add_argument('--bitw', type=str, default='')
|
||||
parser.add_argument('--bita', type=str, default='')
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
if opt.mixm is not None:
|
||||
wmix = torch.load('weights/%s.pt'%opt.mixm)
|
||||
opt.bitw = wmix['extra']['bestw']
|
||||
opt.bita = wmix['extra']['besta']
|
||||
del wmix
|
||||
|
||||
print(opt)
|
||||
|
||||
wdir = 'weights' + os.sep # weights dir
|
||||
last = wdir + '%s_last.pt'%opt.name
|
||||
|
||||
|
|
|
@ -125,22 +125,22 @@ class VGG_tiny_MixQ(nn.Module):
|
|||
nn.BatchNorm2d(64),
|
||||
|
||||
conv_func(64, 64, **conv_kwargs, **qspace), # 1
|
||||
self.pooling,
|
||||
nn.BatchNorm2d(64),
|
||||
self.pooling,
|
||||
|
||||
conv_func(64, 128, **conv_kwargs, **qspace), # 2
|
||||
nn.BatchNorm2d(128),
|
||||
|
||||
conv_func(128, 128, **conv_kwargs, **qspace), # 3
|
||||
self.pooling,
|
||||
nn.BatchNorm2d(128),
|
||||
self.pooling,
|
||||
|
||||
conv_func(128, 256, **conv_kwargs, **qspace), # 4
|
||||
nn.BatchNorm2d(256),
|
||||
|
||||
conv_func(256, 256, **conv_kwargs, **qspace), # 5
|
||||
self.pooling,
|
||||
nn.BatchNorm2d(256),
|
||||
self.pooling,
|
||||
|
||||
nn.Flatten(),
|
||||
qm.QuantActivLinear(256*4*4, num_classes, bias=True, wbit=8, abit=8)
|
||||
|
@ -156,7 +156,7 @@ class VGG_tiny_MixQ(nn.Module):
|
|||
best_arch = None
|
||||
for m in self.modules():
|
||||
if isinstance(m, self.conv_func):
|
||||
layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = m.fetch_best_arch(layer_idx)
|
||||
layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = m.fetch_best_arch(layer_idx)
|
||||
if best_arch is None:
|
||||
best_arch = layer_arch
|
||||
else:
|
||||
|
@ -186,6 +186,17 @@ class VGG_tiny_MixQ(nn.Module):
|
|||
normalizer = size_product[0].item()
|
||||
loss /= normalizer
|
||||
return loss
|
||||
|
||||
def complexity_loss_trivial(self):
|
||||
size_product = []
|
||||
loss = 0
|
||||
for m in self.modules():
|
||||
if isinstance(m, self.conv_func):
|
||||
loss += m.complexity_loss_trivial()
|
||||
size_product += [m.size_product]
|
||||
normalizer = size_product[0].item()
|
||||
loss /= normalizer
|
||||
return loss
|
||||
|
||||
class VGG_tiny_FixQ(nn.Module):
|
||||
def __init__(self, num_classes=10, bitw = '444444', bita = '844444'):
|
||||
|
@ -211,22 +222,22 @@ class VGG_tiny_FixQ(nn.Module):
|
|||
nn.BatchNorm2d(64),
|
||||
|
||||
conv_func(64, 64, **conv_kwargs, wbit=bitw[1], abit=bita[1]), # 1
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
|
||||
conv_func(64, 128, **conv_kwargs, wbit=bitw[2], abit=bita[2]), # 2
|
||||
nn.BatchNorm2d(128),
|
||||
|
||||
conv_func(128, 128, **conv_kwargs, wbit=bitw[3], abit=bita[3]), # 3
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
|
||||
conv_func(128, 256, **conv_kwargs, wbit=bitw[4], abit=bita[4]), # 4
|
||||
nn.BatchNorm2d(256),
|
||||
|
||||
conv_func(256, 256, **conv_kwargs, wbit=bitw[5], abit=bita[5]), # 5
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
|
||||
nn.Flatten(),
|
||||
qm.QuantActivLinear(256*4*4, num_classes, bias=True, wbit=8, abit=8)
|
||||
|
@ -245,7 +256,7 @@ class VGG_tiny_FixQ(nn.Module):
|
|||
bitops = size_product * m.abit * m.wbit
|
||||
bita = m.memory_size.item() * m.abit
|
||||
bitw = m.param_size * m.wbit
|
||||
dsps = size_product / qm.dsp_factors[m.wbit-2][m.abit-2]
|
||||
dsps = size_product / qm.dsp_factors_k33[m.wbit-2][m.abit-2]
|
||||
weight_shape = list(m.conv.weight.shape)
|
||||
print('idx {} with shape {}, bitops: {:.3f}M * {} * {}, memory: {:.3f}K * {}, '
|
||||
'param: {:.3f}M * {}, dsps: {:.3f}M'.format(layer_idx, weight_shape, size_product, m.abit,
|
||||
|
|
|
@ -75,8 +75,9 @@ def train():
|
|||
correct = (predicted == labels).sum().item()
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
if opt.complexity_decay != 0:
|
||||
loss_complexity = opt.complexity_decay * model.complexity_loss()
|
||||
if opt.complexity_decay != 0 or opt.complexity_decay_trivial!=0:
|
||||
loss_complexity = opt.complexity_decay * model.complexity_loss() + \
|
||||
opt.complexity_decay_trivial * model.complexity_loss_trivial()
|
||||
loss += loss_complexity
|
||||
|
||||
loss.backward()
|
||||
|
@ -94,6 +95,8 @@ def train():
|
|||
bitops, bita, bitw, dsps))
|
||||
print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(
|
||||
mixbitops, mixbita, mixbitw, mixdsps))
|
||||
bestw_str = "".join([str(x+2) for x in best_arch["best_weight"]])
|
||||
besta_str = "".join([str(x+2) for x in best_arch["best_activ"]])
|
||||
print(f'best_weight: {best_arch["best_weight"]}')
|
||||
print(f'best_activ: {best_arch["best_activ"]}')
|
||||
|
||||
|
@ -115,7 +118,7 @@ def train():
|
|||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||
'optimizer': None if final_epoch else optimizer.state_dict(),
|
||||
'arch_optimizer': None if final_epoch else arch_optimizer.state_dict(),
|
||||
'extra': {'time': time.ctime(), 'name': opt.name}}
|
||||
'extra': {'time': time.ctime(), 'name': opt.name, 'bestw': bestw_str, 'besta': besta_str}}
|
||||
# Save last checkpoint
|
||||
torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
|
||||
|
||||
|
@ -124,6 +127,12 @@ def train():
|
|||
|
||||
print('Finished Training')
|
||||
|
||||
with open('results.csv', 'a') as f:
|
||||
print("mixed,%s,%d/%d, , , , ,%.1f,%.1f, ,%s,%s,%d,%d,%.3f,%.3f"%
|
||||
(opt.name,epochs-1,epochs,macc*100,(test_acc+test_best_acc)/2,
|
||||
bestw_str,besta_str,
|
||||
int(round(bitops)), int(round(mixbitops)), dsps, mixdsps), file=f)
|
||||
|
||||
# torch.save(net.state_dict(), 'lenet_cifar10.pth')
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -135,6 +144,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--name', default='', help='result and weight file name')
|
||||
parser.add_argument('--noshare', action='store_true', help='no share weight')
|
||||
parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
|
||||
parser.add_argument('--complexity-decay-trivial', '--cdt', default=0, type=float, metavar='W', help='complexity decay w/o hardware-aware')
|
||||
parser.add_argument('--lra', '--learning-rate-alpha', default=0.1, type=float, metavar='LR', help='initial alpha learning rate')
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
|
|
@ -244,6 +244,8 @@ def train():
|
|||
# Update scheduler
|
||||
scheduler.step()
|
||||
|
||||
train_iou = mloss[2]
|
||||
|
||||
# Process epoch results
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if not opt.notest or final_epoch: # Calculate mAP
|
||||
|
@ -293,6 +295,11 @@ def train():
|
|||
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||
dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with open('results.csv', 'a') as f:
|
||||
print("fixed,%s,%d/%d, , ,%s,%s,%.1f,%.1f, , , ,%d, ,%.3f, "%
|
||||
(opt.name,epochs-1,epochs,opt.bitw,opt.bita,train_iou*100,(test_iou+test_best_iou)*50,
|
||||
int(round(bops)), dsps), file=f)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -317,12 +324,19 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
|
||||
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
|
||||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||
parser.add_argument('--mixm', type=str)
|
||||
parser.add_argument('--bitw', type=str, default='')
|
||||
parser.add_argument('--bita', type=str, default='')
|
||||
parser.add_argument('--var', type=float, help='debug variable')
|
||||
parser.add_argument('--model', type=str, default='', help='use specific model')
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
if opt.mixm is not None:
|
||||
wmix = torch.load('weights/%s.pt'%opt.mixm)
|
||||
opt.bitw = wmix['extra']['bestw']
|
||||
opt.bita = wmix['extra']['besta']
|
||||
del wmix
|
||||
last = wdir + 'last_%s.pt'%opt.name
|
||||
opt.weights = last if opt.resume else opt.weights
|
||||
print(opt)
|
||||
|
|
|
@ -152,6 +152,17 @@ def mixq_bram_loss(self):
|
|||
loss /= normalizer
|
||||
return loss
|
||||
|
||||
def mixq_complexity_loss_trivial(self):
|
||||
size_product = []
|
||||
loss = 0
|
||||
for m in self.modules():
|
||||
if isinstance(m, self.conv_func):
|
||||
loss += m.complexity_loss_trivial()
|
||||
size_product += [m.size_product]
|
||||
normalizer = size_product[0].item()
|
||||
loss /= normalizer
|
||||
return loss
|
||||
|
||||
class UltraNet_ismart(nn.Module):
|
||||
def __init__(self):
|
||||
super(UltraNet_ismart, self).__init__()
|
||||
|
@ -367,7 +378,8 @@ class UltraNet_MixQ(nn.Module):
|
|||
fetch_best_arch = mixq_fetch_best_arch
|
||||
|
||||
# def complexity_loss(self):
|
||||
complexity_loss= mixq_complexity_loss
|
||||
complexity_loss = mixq_complexity_loss
|
||||
complexity_loss_trivial = mixq_complexity_loss_trivial
|
||||
bram_loss = mixq_bram_loss
|
||||
|
||||
class UltraNet_FixQ(nn.Module):
|
||||
|
@ -685,7 +697,8 @@ class UltraNetBypass_MixQ(nn.Module):
|
|||
fetch_best_arch = mixq_fetch_best_arch
|
||||
|
||||
# def complexity_loss(self):
|
||||
complexity_loss= mixq_complexity_loss
|
||||
complexity_loss = mixq_complexity_loss
|
||||
mixq_complexity_loss_trivial = mixq_complexity_loss_trivial
|
||||
bram_loss = mixq_bram_loss
|
||||
|
||||
class UltraNetBypass_FixQ(nn.Module):
|
||||
|
@ -876,7 +889,8 @@ class SkyNet_MixQ(nn.Module):
|
|||
fetch_best_arch = mixq_fetch_best_arch
|
||||
|
||||
# def complexity_loss(self):
|
||||
complexity_loss= mixq_complexity_loss
|
||||
complexity_loss = mixq_complexity_loss
|
||||
mixq_complexity_loss_trivial = mixq_complexity_loss_trivial
|
||||
bram_loss = mixq_bram_loss
|
||||
|
||||
class SkyNetk5_MixQ(nn.Module):
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
cds = {
|
||||
'cd':['3e-5', '6e-5', '1e-4', '2e-4', '3e-4'],
|
||||
'cdt':['1e-5', '2e-5', '3e-5', '6e-5', '1e-4'],
|
||||
}
|
||||
|
||||
def search_train():
|
||||
for cd in cds[opt.arg]:
|
||||
name = '%d_%s_'%(opt.it, opt.arg)+cd.replace('-','').replace('.','')
|
||||
os.system('python search_train.py --name %s --cd %s'%('f'+name, cd))
|
||||
|
||||
def main_train():
|
||||
for cd in cds[opt.arg]:
|
||||
name = '%d_%s_'%(opt.it, opt.arg)+cd.replace('-','').replace('.','')
|
||||
os.system('python main_train.py --name %s --mixm %s'%('x'+name, 'f'+name+'_last'))
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--search', action='store_true')
|
||||
parser.add_argument('--main', action='store_true')
|
||||
parser.add_argument('--it', type=int)
|
||||
parser.add_argument('--arg', type=str)
|
||||
opt = parser.parse_args()
|
||||
|
||||
if opt.search:
|
||||
search_train()
|
||||
|
||||
if opt.main:
|
||||
main_train()
|
|
@ -239,16 +239,15 @@ def train():
|
|||
|
||||
# complexity penalty
|
||||
if opt.complexity_decay != 0:
|
||||
if hasattr(model, 'module'):
|
||||
loss_complexity = opt.complexity_decay * model.module.complexity_loss()
|
||||
else:
|
||||
loss_complexity = opt.complexity_decay * model.complexity_loss()
|
||||
loss_complexity = opt.complexity_decay * model.complexity_loss()
|
||||
loss += loss_complexity * 4.0
|
||||
|
||||
if opt.complexity_decay_trivial != 0:
|
||||
loss_complexity_trivial = opt.complexity_decay_trivial * model.complexity_loss_trivial()
|
||||
loss += loss_complexity_trivial * 4.0
|
||||
|
||||
if opt.bram_decay != 0:
|
||||
if hasattr(model, 'module'):
|
||||
loss_bram = opt.bram_decay * model.module.bram_loss()
|
||||
else:
|
||||
loss_bram = opt.bram_decay * model.bram_loss()
|
||||
loss += loss_bram * 4.0
|
||||
|
||||
|
@ -278,13 +277,18 @@ def train():
|
|||
bitops, bita, bitw, dsps))
|
||||
print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M, bram_wa:({:.3f},{:.3f})K'.format(
|
||||
mixbitops, mixbita, mixbitw, mixdsps, mixbram_weight, mixbram_cache))
|
||||
print(f'best_weight: {best_arch["best_weight"]} ({"".join([str(x+2) for x in best_arch["best_weight"]])})')
|
||||
print(f'best_activ: {best_arch["best_activ"]} ({"".join([str(x+2) for x in best_arch["best_activ"]])})')
|
||||
|
||||
|
||||
bestw_str = "".join([str(x+2) for x in best_arch["best_weight"]])
|
||||
besta_str = "".join([str(x+2) for x in best_arch["best_activ"]])
|
||||
print(f'best_weight: {best_arch["best_weight"]}')
|
||||
print(f'best_activ: {best_arch["best_activ"]}')
|
||||
|
||||
# Update scheduler
|
||||
scheduler.step()
|
||||
arch_scheduler.step()
|
||||
|
||||
train_iou = mloss[2]
|
||||
|
||||
# Process epoch results
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if not opt.notest or final_epoch: # Calculate mAP
|
||||
|
@ -314,7 +318,7 @@ def train():
|
|||
'model': model.module.state_dict() if type(
|
||||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||
'optimizer': None if final_epoch else optimizer.state_dict(),
|
||||
'extra': {'time': time.ctime(), 'name': opt.name}}
|
||||
'extra': {'time': time.ctime(), 'name': opt.name, 'bestw': bestw_str, 'besta': besta_str}}
|
||||
|
||||
# Save last checkpoint
|
||||
torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
|
||||
|
@ -333,6 +337,12 @@ def train():
|
|||
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||
dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with open('results.csv', 'a') as f:
|
||||
print("mixed,%s,%d/%d, , , , ,%.1f,%.1f, ,%s,%s,%d,%d,%.3f,%.3f"%
|
||||
(opt.name,epochs-1,epochs,train_iou*100,(test_iou+test_best_iou)*50,
|
||||
bestw_str,besta_str,
|
||||
int(round(bitops)), int(round(mixbitops)), dsps, mixdsps), file=f)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -359,6 +369,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||
parser.add_argument('--var', type=float, help='debug variable')
|
||||
parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
|
||||
parser.add_argument('--complexity-decay-trivial', '--cdt', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
|
||||
parser.add_argument('--bram-decay', '--bd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
|
||||
parser.add_argument('--lra', '--learning-rate-alpha', default=0.01, type=float, metavar='LR', help='initial alpha learning rate')
|
||||
parser.add_argument('--no-share', action='store_true', help='no share weight quantization')
|
||||
|
|
Loading…
Reference in New Issue