add bram loss (WIP)

This commit is contained in:
fffasttime 2023-02-20 17:11:51 +08:00
parent ed6b0ec2fa
commit 87f943e14a
4 changed files with 99 additions and 21 deletions


@@ -209,10 +209,14 @@ class QuantActivConv2d(nn.Module):
self.kernel_size = kwargs['kernel_size']
if 'groups' in kwargs: groups = kwargs['groups']
else: groups = 1
self.inplane = inplane
self.outplane = outplane
self.groups = groups
self.param_size = inplane * outplane * kernel_size * 1e-6 / groups
self.filter_size = self.param_size / float(stride ** 2.0)
self.register_buffer('size_product', torch.tensor(0, dtype=torch.float))
self.register_buffer('memory_size', torch.tensor(0, dtype=torch.float))
self.register_buffer('in_width', torch.tensor(0, dtype=torch.float))
def forward(self, input):
in_shape = input.shape
@@ -221,6 +225,8 @@ class QuantActivConv2d(nn.Module):
tmp = torch.tensor(self.filter_size * in_shape[-1] * in_shape[-2], dtype=torch.float)
self.size_product.copy_(tmp)
out = self.activ(input)
tmp = torch.tensor(input.shape[3], dtype=torch.float)
self.in_width.copy_(tmp)
## print('ii',input[0,0,:,0]/self.activ.step)
## print('convi', torch.round(out[0,0,:,0]/self.activ.step).int())
## wstd = self.conv.weight.std()
@@ -363,10 +369,14 @@ class MixActivConv2d(nn.Module):
if 'groups' in kwargs: groups = kwargs['groups']
else: groups = 1
self.inplane = inplane
self.outplane = outplane
self.groups = groups
self.param_size = inplane * outplane * kernel_size * 1e-6 / groups
self.filter_size = self.param_size / float(stride ** 2.0)
self.register_buffer('size_product', torch.tensor(0, dtype=torch.float))
self.register_buffer('memory_size', torch.tensor(0, dtype=torch.float))
self.register_buffer('in_width', torch.tensor(0, dtype=torch.float))
def forward(self, input):
in_shape = input.shape
@@ -374,6 +384,8 @@ class MixActivConv2d(nn.Module):
self.memory_size.copy_(tmp)
tmp = torch.tensor(self.filter_size * in_shape[-1] * in_shape[-2], dtype=torch.float)
self.size_product.copy_(tmp)
tmp = torch.tensor(input.shape[3], dtype=torch.float)
self.in_width.copy_(tmp)
out = self.mix_activ(input)
out = self.mix_weight(out)
return out
@@ -412,7 +424,30 @@ class MixActivConv2d(nn.Module):
mix_scale += sw[i] * sa[j] / dsp_factors[wbits[i]-2][abits[j]-2]
complexity = self.size_product.item() * 64 * mix_scale
return complexity
def bram_loss(self):
sa = F.softmax(self.mix_activ.alpha_activ, dim=0)
abits = self.mix_activ.bits
sw = F.softmax(self.mix_weight.alpha_weight, dim=0)
wbits = self.mix_weight.bits
if self.kernel_size == 1:
bram_sw = 2 * self.in_width.item() * self.inplane
else: # sliding window size
bram_sw = (self.kernel_size+1)*self.in_width.item()*self.inplane
bram_sw *= 1e-3
mix_wbit, mix_abit = 0, 0
for i in range(len(wbits)):
mix_wbit += sw[i] * wbits[i]
for i in range(len(abits)):
mix_abit += sa[i] * abits[i]
bram_weight = self.param_size * 1e3 * mix_wbit # kbit
bram_cache = bram_sw * mix_abit # kbit
bram = (bram_weight + bram_cache) * 64
return bram
def fetch_best_arch(self, layer_idx):
size_product = float(self.size_product.cpu().numpy())
@@ -435,9 +470,16 @@ class MixActivConv2d(nn.Module):
weight_shape = list(self.mix_weight.conv.weight.shape)
else:
weight_shape = list(self.mix_weight.conv_list[0].weight.shape)
if self.kernel_size == 1:
bram_sw = 2 * self.in_width.item() * self.inplane
else:
bram_sw = (self.kernel_size+1)*self.in_width.item()*self.inplane
bram_sw *= 1e-3
print('idx {} with shape {}, activ alpha: {}, comp: {:.3f}M * {:.3f} * {:.3f}, '
'memory: {:.3f}K * {:.3f}'.format(layer_idx, weight_shape, prob_activ, size_product,
mix_abit, mix_wbit, memory_size, mix_abit))
'memory: {:.3f}K * {:.3f}, cache: {:.3f}K'.format(layer_idx, weight_shape, prob_activ, size_product,
mix_abit, mix_wbit, memory_size, mix_abit, bram_sw))
print('idx {} with shape {}, weight alpha: {}, comp: {:.3f}M * {:.3f} * {:.3f}, '
'param: {:.3f}M * {:.3f}'.format(layer_idx, weight_shape, prob_weight, size_product,
mix_abit, mix_wbit, self.param_size, mix_wbit))
@@ -463,8 +505,10 @@ class MixActivConv2d(nn.Module):
for j in range(len(abits)):
mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors[wbits[i]-2][abits[j]-2]
mixdsps *= size_product
mixbram_weight = self.param_size * 1e3 * mix_wbit # kbit
mixbram_cache = bram_sw * mix_abit # kbit
return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps
return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache
class SharedMixQuantLinear(nn.Module):
@@ -586,6 +630,6 @@ class MixActivLinear(nn.Module):
mixdsps = 0
for i in range(len(wbits)):
for j in range(len(abits)):
mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors[wbits[i]-2][abits[j]-2]
mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
mixdsps *= size_product
return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps
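For reference, a minimal standalone sketch of the per-layer estimate that the new bram_loss above encodes (illustrative only, not part of this commit; the function name and arguments are made up): weight storage is the parameter count at the weight bit-width, and the activation cache is a sliding-window line buffer of kernel_size+1 input rows (two rows for a 1x1 kernel) at the activation bit-width. bram_loss uses the softmax-weighted expected bit-widths instead of fixed ones and scales the sum by 64 to match the scale of complexity_loss.
def estimate_layer_bram_kbit(param_size_m, kernel_size, inplane, in_width, wbit, abit):
    # param_size_m: parameter count in Mparams, as stored in self.param_size
    bram_weight = param_size_m * 1e3 * wbit                # weight storage, kbit
    # line buffer: 2 input rows for a 1x1 kernel, kernel_size+1 rows otherwise,
    # each row holding in_width * inplane activations
    rows = 2 if kernel_size == 1 else kernel_size + 1
    bram_cache = rows * in_width * inplane * 1e-3 * abit   # activation cache, kbit
    return bram_weight + bram_cache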


@@ -152,6 +152,13 @@ def train():
num_workers=0,
pin_memory=True,
collate_fn=testset.collate_fn)
test.test(batch_size=batch_size,
img_size=img_size_test,
model=model,
dataloader=testloader) # make forward
bops, bita, bitw, dsps, brams = model.fetch_arch_info()
print('model with bops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M, bram: {:.3f}K'.format(bops, bita, bitw, dsps, brams))
# Dataset
dataset = LoadImagesAndLabels(train_path, img_size, batch_size,
@@ -181,13 +188,6 @@ def train():
torch_utils.model_info(model, report='summary') # 'full' or 'summary'
print('Using %g dataloader workers' % nw)
print('Starting training for %g epochs...' % epochs)
test.test(batch_size=batch_size,
img_size=img_size_test,
model=model,
dataloader=testloader) # make forward
bops, bita, bitw, dsps = model.fetch_arch_info()
print('model with bops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(bops, bita, bitw, dsps))
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
model.train()


@@ -58,7 +58,7 @@ class YOLOLayer(nn.Module):
return io.view(bs, -1, self.no), p
def fixq_fetch_arch_info(self):
sum_bitops, sum_bita, sum_bitw, sum_dsps = 0, 0, 0, 0
sum_bitops, sum_bita, sum_bitw, sum_dsps, sum_bram = 0, 0, 0, 0, 0
layer_idx = 0
for m in self.modules():
if isinstance(m, self.conv_func):
@@ -79,27 +79,36 @@ def fixq_fetch_arch_info(self):
dsp_factors = qm.dsp_factors_k55
else:
raise NotImplementedError
if m.kernel_size == 1:
bram_sw = 2 * m.in_width.item() * m.inplane
else:
bram_sw = (m.kernel_size+1)*m.in_width.item()*m.inplane
bram_sw *= 1e-3
dsps = size_product / dsp_factors[m.wbit-2][m.abit-2]
weight_shape = list(m.conv.weight.shape)
print('idx {} with shape {}, bitops: {:.3f}M * {} * {}, memory: {:.3f}K * {}, '
'param: {:.3f}M * {}, dsps: {:.3f}M'.format(layer_idx, weight_shape, size_product, m.abit,
m.wbit, memory_size, m.abit, m.param_size, m.wbit, dsps))
'param: {:.3f}M * {}, dsps: {:.3f}M, bram(wa|waf):({:.2f},{:.2f}|{:.1f},{:.1f},{:.1f})K'.format(layer_idx, weight_shape, size_product, m.abit,
m.wbit, memory_size, m.abit, m.param_size, m.wbit, dsps,
m.param_size * 1e3, bram_sw, bitw*1e3, bram_sw*m.abit, bitw*1e3 + bram_sw*m.abit))
sum_bitops += bitops
sum_bita += bita
sum_bitw += bitw
sum_dsps += dsps
sum_bram += bitw*1e3 + bram_sw*m.abit
layer_idx += 1
return sum_bitops, sum_bita, sum_bitw, sum_dsps
return sum_bitops, sum_bita, sum_bitw, sum_dsps, sum_bram
def mixq_fetch_best_arch(self):
sum_bitops, sum_bita, sum_bitw, sum_dsps = 0, 0, 0, 0
sum_mixbitops, sum_mixbita, sum_mixbitw, sum_mixdsps = 0, 0, 0, 0
sum_mixbram_weight, sum_mixbram_cache = 0, 0
layer_idx = 0
best_arch = None
for m in self.modules():
if isinstance(m, self.conv_func):
layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = m.fetch_best_arch(layer_idx)
layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = m.fetch_best_arch(layer_idx)
if best_arch is None:
best_arch = layer_arch
else:
@@ -116,8 +125,10 @@ def mixq_fetch_best_arch(self):
sum_mixbitw += mixbitw
sum_dsps += dsps
sum_mixdsps += mixdsps
sum_mixbram_weight += mixbram_weight
sum_mixbram_cache += mixbram_cache
layer_idx += 1
return best_arch, sum_bitops, sum_bita, sum_bitw, sum_mixbitops, sum_mixbita, sum_mixbitw, sum_dsps, sum_mixdsps
return best_arch, sum_bitops, sum_bita, sum_bitw, sum_mixbitops, sum_mixbita, sum_mixbitw, sum_dsps, sum_mixdsps, sum_mixbram_weight, sum_mixbram_cache
def mixq_complexity_loss(self):
size_product = []
@@ -130,6 +141,17 @@ def mixq_complexity_loss(self):
loss /= normalizer
return loss
def mixq_bram_loss(self):
memory_sizes = []
loss = 0
for m in self.modules():
if isinstance(m, self.conv_func):
loss += m.bram_loss()
memory_sizes += [m.memory_size.item()]
normalizer = memory_sizes[0]
loss /= normalizer
return loss
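For a sense of scale, a made-up 3x3 layer with 3 input channels, 16 output channels, a 320-pixel-wide input and 4-bit weights and activations (all numbers illustrative, not taken from any model here) contributes roughly:
# weights:     3 * 16 * 3 * 3 = 432 params   -> 432 * 4 bit  ~ 1.7 kbit
# line buffer: (3 + 1) * 320 * 3 = 3840 acts -> 3840 * 4 bit ~ 15.4 kbit
so for early, wide layers the activation cache dominates, which is presumably what this term is meant to penalize on top of mixq_complexity_loss. As in that loss, the summed layer terms are divided by a first-layer normalizer (here the first conv layer's memory_size).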
class UltraNet_ismart(nn.Module):
def __init__(self):
super(UltraNet_ismart, self).__init__()
@@ -346,6 +368,7 @@ class UltraNet_MixQ(nn.Module):
# def complexity_loss(self):
complexity_loss= mixq_complexity_loss
bram_loss = mixq_bram_loss
class UltraNet_FixQ(nn.Module):
def __init__(self, bitw = '444444444', bita = '444444444'):
@@ -663,6 +686,7 @@ class UltraNetBypass_MixQ(nn.Module):
# def complexity_loss(self):
complexity_loss= mixq_complexity_loss
bram_loss = mixq_bram_loss
class UltraNetBypass_FixQ(nn.Module):
def __init__(self, bitw = '444444444', bita = '444444444'):
@@ -853,6 +877,7 @@ class SkyNet_MixQ(nn.Module):
# def complexity_loss(self):
complexity_loss= mixq_complexity_loss
bram_loss = mixq_bram_loss
class SkyNetk5_MixQ(nn.Module):
def __init__(self, share_weight = False):
@@ -907,6 +932,7 @@ class SkyNetk5_MixQ(nn.Module):
# def complexity_loss(self):
complexity_loss= mixq_complexity_loss
bram_loss = mixq_bram_loss
class SkyNet_FixQ(nn.Module):
def __init__(self, bitw='', bita=''):


@@ -245,6 +245,13 @@ def train():
loss_complexity = opt.complexity_decay * model.complexity_loss()
loss += loss_complexity * 4.0
if opt.bram_decay != 0:
if hasattr(model, 'module'):
loss_bram = opt.bram_decay * model.module.bram_loss()
else:
loss_bram = opt.bram_decay * model.bram_loss()
loss += loss_bram * 4.0
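# note: the 4.0 factor mirrors the complexity_decay term above; the scaling choice is not explained in this diff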
loss.backward()
# Optimize accumulated gradient
@@ -264,13 +271,13 @@ def train():
print('========= architecture =========')
if hasattr(model, 'module'):
best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = model.module.fetch_best_arch()
best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = model.module.fetch_best_arch()
else:
best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = model.fetch_best_arch()
best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = model.fetch_best_arch()
print('best model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(
bitops, bita, bitw, dsps))
print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(
mixbitops, mixbita, mixbitw, mixdsps))
print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M, bram_wa:({:.3f},{:.3f})K'.format(
mixbitops, mixbita, mixbitw, mixdsps, mixbram_weight, mixbram_cache))
print(f'best_weight: {best_arch["best_weight"]} ({"".join([str(x+2) for x in best_arch["best_weight"]])})')
print(f'best_activ: {best_arch["best_activ"]} ({"".join([str(x+2) for x in best_arch["best_activ"]])})')
@@ -352,6 +359,7 @@ if __name__ == '__main__':
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--var', type=float, help='debug variable')
parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
parser.add_argument('--bram-decay', '--bd', default=0, type=float, metavar='W', help='bram decay (default: 0)')
parser.add_argument('--lra', '--learning-rate-alpha', default=0.01, type=float, metavar='LR', help='initial alpha learning rate')
parser.add_argument('--no-share', action='store_true', help='no share weight quantization')
parser.add_argument('--model', type=str, default='', help='use specific model')
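Assuming the usual command-line invocation of this search script (the script path is not shown in this diff, and the decay values below are placeholders), the new knob is passed alongside the existing complexity decay, e.g.:
python <search_script>.py ... --complexity-decay 1e-4 --bram-decay 1e-4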