forked from liucheng/DeepBurning-MixQ

add bram loss (WIP)

This commit is contained in:
parent ed6b0ec2fa
commit 87f943e14a
@@ -209,10 +209,14 @@ class QuantActivConv2d(nn.Module):
         self.kernel_size = kwargs['kernel_size']
         if 'groups' in kwargs: groups = kwargs['groups']
         else: groups = 1
+        self.inplane = inplane
+        self.outplane = outplane
+        self.groups = groups
         self.param_size = inplane * outplane * kernel_size * 1e-6 / groups
         self.filter_size = self.param_size / float(stride ** 2.0)
         self.register_buffer('size_product', torch.tensor(0, dtype=torch.float))
         self.register_buffer('memory_size', torch.tensor(0, dtype=torch.float))
+        self.register_buffer('in_width', torch.tensor(0, dtype=torch.float))
 
     def forward(self, input):
         in_shape = input.shape
@@ -221,6 +225,8 @@ class QuantActivConv2d(nn.Module):
         tmp = torch.tensor(self.filter_size * in_shape[-1] * in_shape[-2], dtype=torch.float)
         self.size_product.copy_(tmp)
         out = self.activ(input)
+        tmp = torch.tensor(input.shape[3], dtype=torch.float)
+        self.in_width.copy_(tmp)
         ## print('ii',input[0,0,:,0]/self.activ.step)
         ## print('convi', torch.round(out[0,0,:,0]/self.activ.step).int())
         ## wstd = self.conv.weight.std()
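The two lines added to forward() record the incoming feature-map width (axis 3 of an NCHW tensor) so the BRAM model introduced below can size line buffers after a single warm-up pass. A minimal sketch of the register_buffer/copy_ pattern this relies on; the WidthProbe module is illustrative, not part of the patch:

    import torch
    import torch.nn as nn

    class WidthProbe(nn.Module):
        """Illustrative only: records input width like the patched forward()."""
        def __init__(self):
            super(WidthProbe, self).__init__()
            # Non-trainable state: saved in state_dict, moved by .to(device).
            self.register_buffer('in_width', torch.tensor(0, dtype=torch.float))

        def forward(self, x):
            # x is NCHW; remember W so a cost model can size line buffers later.
            self.in_width.copy_(torch.tensor(x.shape[3], dtype=torch.float))
            return x

    probe = WidthProbe()
    probe(torch.zeros(1, 3, 16, 20))
    print(probe.in_width.item())  # 20.0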
@@ -363,10 +369,14 @@ class MixActivConv2d(nn.Module):
 
         if 'groups' in kwargs: groups = kwargs['groups']
         else: groups = 1
+        self.inplane = inplane
+        self.outplane = outplane
+        self.groups = groups
         self.param_size = inplane * outplane * kernel_size * 1e-6 / groups
         self.filter_size = self.param_size / float(stride ** 2.0)
         self.register_buffer('size_product', torch.tensor(0, dtype=torch.float))
         self.register_buffer('memory_size', torch.tensor(0, dtype=torch.float))
+        self.register_buffer('in_width', torch.tensor(0, dtype=torch.float))
 
     def forward(self, input):
         in_shape = input.shape
@@ -374,6 +384,8 @@ class MixActivConv2d(nn.Module):
         self.memory_size.copy_(tmp)
         tmp = torch.tensor(self.filter_size * in_shape[-1] * in_shape[-2], dtype=torch.float)
         self.size_product.copy_(tmp)
+        tmp = torch.tensor(input.shape[3], dtype=torch.float)
+        self.in_width.copy_(tmp)
         out = self.mix_activ(input)
         out = self.mix_weight(out)
         return out
@@ -412,7 +424,30 @@ class MixActivConv2d(nn.Module):
                 mix_scale += sw[i] * sa[j] / dsp_factors[wbits[i]-2][abits[j]-2]
         complexity = self.size_product.item() * 64 * mix_scale
         return complexity
 
+    def bram_loss(self):
+        sa = F.softmax(self.mix_activ.alpha_activ, dim=0)
+        abits = self.mix_activ.bits
+        sw = F.softmax(self.mix_weight.alpha_weight, dim=0)
+        wbits = self.mix_weight.bits
+
+        if self.kernel_size == 1:
+            bram_sw = 2 * self.in_width.item() * self.inplane
+        else:  # sliding window size
+            bram_sw = (self.kernel_size+1)*self.in_width.item()*self.inplane
+        bram_sw *= 1e-3
+
+        mix_wbit, mix_abit = 0, 0
+        for i in range(len(wbits)):
+            mix_wbit += sw[i] * wbits[i]
+        for i in range(len(abits)):
+            mix_abit += sa[i] * abits[i]
+
+        bram_weight = self.param_size * 1e3 * mix_wbit  # kbit
+        bram_cache = bram_sw * mix_abit  # kbit
+
+        bram = (bram_weight + bram_cache) * 64
+        return bram
 
     def fetch_best_arch(self, layer_idx):
         size_product = float(self.size_product.cpu().numpy())
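The new bram_loss() keeps the BRAM estimate differentiable the same way complexity_loss handles DSPs: the hard bit-widths are replaced by their expectations under the softmax of the architecture parameters, and the activation cache is modeled as a line buffer of K+1 input rows for a KxK sliding window (two rows when K=1). A standalone sketch of the expectation step; the logits and the candidate set {2, 4, 8} are illustrative:

    import torch
    import torch.nn.functional as F

    def expected_bits(alpha, bits):
        """E[bit-width] under softmax(alpha); differentiable w.r.t. alpha."""
        probs = F.softmax(alpha, dim=0)
        return sum(p * b for p, b in zip(probs, bits))

    alpha_w = torch.zeros(3, requires_grad=True)   # logits over {2, 4, 8} bits
    mix_wbit = expected_bits(alpha_w, [2, 4, 8])   # 14/3 ~ 4.667 bits at init
    bram_weight = 0.036 * 1e3 * mix_wbit           # param_size[M] * 1e3 * bits -> kbit
    bram_weight.backward()                         # gradients flow to the alphas
    print(alpha_w.grad)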
@@ -435,9 +470,16 @@ class MixActivConv2d(nn.Module):
             weight_shape = list(self.mix_weight.conv.weight.shape)
         else:
             weight_shape = list(self.mix_weight.conv_list[0].weight.shape)
+
+        if self.kernel_size == 1:
+            bram_sw = 2 * self.in_width.item() * self.inplane
+        else:
+            bram_sw = (self.kernel_size+1)*self.in_width.item()*self.inplane*self.outplane/self.groups
+        bram_sw *= 1e-3
+
         print('idx {} with shape {}, activ alpha: {}, comp: {:.3f}M * {:.3f} * {:.3f}, '
-              'memory: {:.3f}K * {:.3f}'.format(layer_idx, weight_shape, prob_activ, size_product,
-                                                mix_abit, mix_wbit, memory_size, mix_abit))
+              'memory: {:.3f}K * {:.3f}, cache: {:.3f}K'.format(layer_idx, weight_shape, prob_activ, size_product,
+                                                                mix_abit, mix_wbit, memory_size, mix_abit, bram_sw))
         print('idx {} with shape {}, weight alpha: {}, comp: {:.3f}M * {:.3f} * {:.3f}, '
               'param: {:.3f}M * {:.3f}'.format(layer_idx, weight_shape, prob_weight, size_product,
                                                mix_abit, mix_wbit, self.param_size, mix_wbit))
@@ -463,8 +505,10 @@ class MixActivConv2d(nn.Module):
             for j in range(len(abits)):
                 mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors[wbits[i]-2][abits[j]-2]
         mixdsps *= size_product
+        mixbram_weight = self.param_size * 1e3 * mix_wbit  # kbit
+        mixbram_cache = bram_sw * mix_abit  # kbit
 
-        return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps
+        return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache
 
 
 class SharedMixQuantLinear(nn.Module):
@@ -586,6 +630,6 @@ class MixActivLinear(nn.Module):
         mixdsps = 0
         for i in range(len(wbits)):
             for j in range(len(abits)):
-                mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors[wbits[i]-2][abits[j]-2]
+                mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
         mixdsps *= size_product
         return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps
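The single-line change in MixActivLinear is a factor-table fix: a fully-connected layer maps to 1x1-kernel hardware, so it should index dsp_factors_k11 rather than the generic table. The lookup convention, with placeholder values (the real tables live in the quant module):

    # [wbit-2][abit-2] indexing, as in the patch; the values below are
    # placeholders, not the repo's measured factors.
    dsp_factors_k11 = [[4.0, 2.0],
                       [2.0, 1.0]]

    def mixdsps_term(prob_w, prob_a, wbit, abit):
        return prob_w * prob_a / dsp_factors_k11[wbit - 2][abit - 2]

    print(mixdsps_term(0.5, 0.5, 3, 2))  # 0.25 / 2.0 = 0.125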
@@ -152,6 +152,13 @@ def train():
                                 num_workers=0,
                                 pin_memory=True,
                                 collate_fn=testset.collate_fn)
+
+    test.test(batch_size=batch_size,
+              img_size=img_size_test,
+              model=model,
+              dataloader=testloader) # make forward
+    bops, bita, bitw, dsps, brams = model.fetch_arch_info()
+    print('model with bops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M, bram: {:.3f}K'.format(bops, bita, bitw, dsps, brams))
 
     # Dataset
     dataset = LoadImagesAndLabels(train_path, img_size, batch_size,
@@ -181,13 +188,6 @@ def train():
     torch_utils.model_info(model, report='summary')  # 'full' or 'summary'
     print('Using %g dataloader workers' % nw)
     print('Starting training for %g epochs...' % epochs)
-
-    test.test(batch_size=batch_size,
-              img_size=img_size_test,
-              model=model,
-              dataloader=testloader) # make forward
-    bops, bita, bitw, dsps = model.fetch_arch_info()
-    print('model with bops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(bops, bita, bitw, dsps))
 
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
         model.train()
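These two train() hunks move the warm-up test.test() pass from after model_info to right after the test loader is built, and switch to the five-value fetch_arch_info(). The ordering matters: size_product and the new in_width buffers are only written inside forward(), so the cost report is meaningful only after at least one pass. A sketch of that constraint (helper name and dummy input are illustrative, not from the patch):

    import torch

    def report_costs(model, sample):
        """Illustrative: one dummy pass, then read the cost model."""
        model.eval()
        with torch.no_grad():
            model(sample)               # populates size_product / in_width buffers
        return model.fetch_arch_info()  # now includes the bram total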
@@ -58,7 +58,7 @@ class YOLOLayer(nn.Module):
         return io.view(bs, -1, self.no), p
 
 def fixq_fetch_arch_info(self):
-    sum_bitops, sum_bita, sum_bitw, sum_dsps = 0, 0, 0, 0
+    sum_bitops, sum_bita, sum_bitw, sum_dsps, sum_bram = 0, 0, 0, 0, 0
     layer_idx = 0
     for m in self.modules():
         if isinstance(m, self.conv_func):
@@ -79,27 +79,36 @@ def fixq_fetch_arch_info(self):
                 dsp_factors = qm.dsp_factors_k55
             else:
                 raise NotImplementedError
+
+            if m.kernel_size == 1:
+                bram_sw = 2 * m.in_width.item() * m.inplane
+            else:
+                bram_sw = (m.kernel_size+1)*m.in_width.item()*m.inplane
+            bram_sw *= 1e-3
 
             dsps = size_product / dsp_factors[m.wbit-2][m.abit-2]
             weight_shape = list(m.conv.weight.shape)
             print('idx {} with shape {}, bitops: {:.3f}M * {} * {}, memory: {:.3f}K * {}, '
-                  'param: {:.3f}M * {}, dsps: {:.3f}M'.format(layer_idx, weight_shape, size_product, m.abit,
-                                                              m.wbit, memory_size, m.abit, m.param_size, m.wbit, dsps))
+                  'param: {:.3f}M * {}, dsps: {:.3f}M, bram(wa|waf):({:.2f},{:.2f}|{:.1f},{:.1f},{:.1f})K'.format(layer_idx, weight_shape, size_product, m.abit,
+                                                              m.wbit, memory_size, m.abit, m.param_size, m.wbit, dsps,
+                                                              m.param_size * 1e3, bram_sw, bitw*1e3, bram_sw*m.abit, bitw*1e3 + bram_sw*m.abit))
             sum_bitops += bitops
             sum_bita += bita
             sum_bitw += bitw
             sum_dsps += dsps
+            sum_bram += bitw*1e3 + bram_sw*m.abit
             layer_idx += 1
-    return sum_bitops, sum_bita, sum_bitw, sum_dsps
+    return sum_bitops, sum_bita, sum_bitw, sum_dsps, sum_bram
 
 def mixq_fetch_best_arch(self):
     sum_bitops, sum_bita, sum_bitw, sum_dsps = 0, 0, 0, 0
     sum_mixbitops, sum_mixbita, sum_mixbitw, sum_mixdsps = 0, 0, 0, 0
+    sum_mixbram_weight, sum_mixbram_cache = 0, 0
     layer_idx = 0
     best_arch = None
     for m in self.modules():
         if isinstance(m, self.conv_func):
-            layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = m.fetch_best_arch(layer_idx)
+            layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = m.fetch_best_arch(layer_idx)
             if best_arch is None:
                 best_arch = layer_arch
             else:
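The widened per-layer printout reports bram(wa|waf) in two groups: 'wa' is the raw weight count (k-params) and line-buffer size (k-elements), and 'waf' is weight kbit, cache kbit, and their sum, which is also the quantity accumulated into sum_bram. A worked example with assumed layer numbers, taking bitw = param_size * wbit as in the repo's fixed-point accounting:

    param_size, wbit, abit = 0.036, 4, 4       # assumed: 0.036M params, 4-bit w/a
    bitw = param_size * wbit                   # 0.144 Mbit of weights
    bram_sw = (3 + 1) * 80 * 64 * 1e-3         # K=3, W=80, C_in=64 -> 20.48 k-elements
    print(param_size * 1e3, bram_sw)           # wa  -> 36.0, 20.48
    print(bitw * 1e3, bram_sw * abit,          # waf -> 144.0 kbit weights,
          bitw * 1e3 + bram_sw * abit)         #        81.92 kbit cache, 225.92 total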
@@ -116,8 +125,10 @@ def mixq_fetch_best_arch(self):
             sum_mixbitw += mixbitw
             sum_dsps += dsps
             sum_mixdsps += mixdsps
+            sum_mixbram_weight += mixbram_weight
+            sum_mixbram_cache += mixbram_cache
             layer_idx += 1
-    return best_arch, sum_bitops, sum_bita, sum_bitw, sum_mixbitops, sum_mixbita, sum_mixbitw, sum_dsps, sum_mixdsps
+    return best_arch, sum_bitops, sum_bita, sum_bitw, sum_mixbitops, sum_mixbita, sum_mixbitw, sum_dsps, sum_mixdsps, sum_mixbram_weight, sum_mixbram_cache
 
 def mixq_complexity_loss(self):
     size_product = []
@@ -130,6 +141,17 @@ def mixq_complexity_loss(self):
     loss /= normalizer
     return loss
+
+def mixq_bram_loss(self):
+    memory_sizes = []
+    loss = 0
+    for m in self.modules():
+        if isinstance(m, self.conv_func):
+            loss += m.bram_loss()
+            memory_sizes += [m.memory_size.item()]
+    normalizer = memory_sizes[0]
+    loss /= normalizer
+    return loss
 
 class UltraNet_ismart(nn.Module):
     def __init__(self):
         super(UltraNet_ismart, self).__init__()
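mixq_bram_loss mirrors mixq_complexity_loss: sum the differentiable per-layer BRAM estimates, then normalize by the first conv layer's memory_size so the penalty's magnitude does not scale with input resolution. An equivalent standalone sketch (the layer iterable stands in for the self.modules() filtering):

    def bram_regularizer(conv_layers):
        """conv_layers: modules exposing bram_loss() and a memory_size buffer."""
        loss = 0
        memory_sizes = []
        for m in conv_layers:
            loss = loss + m.bram_loss()        # differentiable, kbit-scaled
            memory_sizes.append(m.memory_size.item())
        return loss / memory_sizes[0]          # normalize by the first layer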
@@ -346,6 +368,7 @@ class UltraNet_MixQ(nn.Module):
 
     # def complexity_loss(self):
     complexity_loss= mixq_complexity_loss
+    bram_loss = mixq_bram_loss
 
 class UltraNet_FixQ(nn.Module):
     def __init__(self, bitw = '444444444', bita = '444444444'):

@@ -663,6 +686,7 @@ class UltraNetBypass_MixQ(nn.Module):
 
     # def complexity_loss(self):
     complexity_loss= mixq_complexity_loss
+    bram_loss = mixq_bram_loss
 
 class UltraNetBypass_FixQ(nn.Module):
     def __init__(self, bitw = '444444444', bita = '444444444'):

@@ -853,6 +877,7 @@ class SkyNet_MixQ(nn.Module):
 
     # def complexity_loss(self):
     complexity_loss= mixq_complexity_loss
+    bram_loss = mixq_bram_loss
 
 class SkyNetk5_MixQ(nn.Module):
     def __init__(self, share_weight = False):

@@ -907,6 +932,7 @@ class SkyNetk5_MixQ(nn.Module):
 
     # def complexity_loss(self):
     complexity_loss= mixq_complexity_loss
+    bram_loss = mixq_bram_loss
 
 class SkyNet_FixQ(nn.Module):
     def __init__(self, bitw='', bita=''):
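Each MixQ network picks up the new loss exactly as it does complexity_loss: a module-level function whose first parameter is self is assigned in the class body and thereby becomes an ordinary method, so all four networks share one implementation without a common base class. A minimal illustration of the pattern:

    def shared_loss(self):
        return 2 * self.scale

    class NetA:
        scale = 3
        loss = shared_loss   # plain function assigned in the body -> bound method

    print(NetA().loss())     # 6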
@@ -245,6 +245,13 @@ def train():
                    loss_complexity = opt.complexity_decay * model.complexity_loss()
                loss += loss_complexity * 4.0
 
+            if opt.bram_decay != 0:
+                if hasattr(model, 'module'):
+                    loss_bram = opt.bram_decay * model.module.bram_loss()
+                else:
+                    loss_bram = opt.bram_decay * model.bram_loss()
+                loss += loss_bram * 4.0
+
             loss.backward()
 
             # Optimize accumulated gradient
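The penalty is gated on the new --bram-decay flag, unwraps nn.DataParallel through the hasattr(model, 'module') check like the complexity path, and reuses the same hard-coded 4.0 multiplier. A compact sketch of the combined objective (the helper name is illustrative; opt carries the parsed flags):

    def total_loss(task_loss, model, opt):
        """Illustrative wiring of the two regularizers from the training loop."""
        loss = task_loss
        if opt.complexity_decay != 0:
            loss = loss + opt.complexity_decay * model.complexity_loss() * 4.0
        if opt.bram_decay != 0:
            loss = loss + opt.bram_decay * model.bram_loss() * 4.0
        return loss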
@@ -264,13 +271,13 @@ def train():
 
     print('========= architecture =========')
     if hasattr(model, 'module'):
-        best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = model.module.fetch_best_arch()
+        best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = model.module.fetch_best_arch()
     else:
-        best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = model.fetch_best_arch()
+        best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = model.fetch_best_arch()
     print('best model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(
         bitops, bita, bitw, dsps))
-    print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(
-        mixbitops, mixbita, mixbitw, mixdsps))
+    print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M, bram_wa:({:.3f},{:.3f})K'.format(
+        mixbitops, mixbita, mixbitw, mixdsps, mixbram_weight, mixbram_cache))
     print(f'best_weight: {best_arch["best_weight"]} ({"".join([str(x+2) for x in best_arch["best_weight"]])})')
     print(f'best_activ: {best_arch["best_activ"]} ({"".join([str(x+2) for x in best_arch["best_activ"]])})')
 
@@ -352,6 +359,7 @@ if __name__ == '__main__':
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
     parser.add_argument('--var', type=float, help='debug variable')
     parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
+    parser.add_argument('--bram-decay', '--bd', default=0, type=float, metavar='W', help='bram decay (default: 0)')
     parser.add_argument('--lra', '--learning-rate-alpha', default=0.01, type=float, metavar='LR', help='initial alpha learning rate')
    parser.add_argument('--no-share', action='store_true', help='no share weight quantization')
    parser.add_argument('--model', type=str, default='', help='use specific model')