diff --git a/anypacking/dsp_packing.py b/anypacking/dsp_packing.py
new file mode 100644
index 0000000..f29de90
--- /dev/null
+++ b/anypacking/dsp_packing.py
@@ -0,0 +1,29 @@
+factors_k11=[
+[12,8,8,6,6,4,4],
+[10,8,6,6,4,4,4],
+[8,6,6,4,4,4,3],
+[6,6,4,4,4,4,2],
+[6,4,4,4,2,2,2],
+[4,4,4,4,2,2,2],
+[4,4,3,2,2,2,2],
+]
+
+factors_k33=[
+[18,15,12,7.5,7.5,6,6],
+[15,12,7.5,6,6,6,3],
+[12,7.5,6,6,6,6,3],
+[9,6,6,6,6,3,3],
+[7.5,6,6,4.5,3,3,3],
+[6,6,4.5,3,3,3,2.25],
+[6,3,3,3,3,3,2],
+]
+
+factors_k55=[
+[20,15,10,7.5,7.5,5,5],
+[12.5,10,6.67,5,5,5,3.33],
+[10,7.5,5,5,5,5,3.33],
+[7.5,6.67,5,5,5,3.33,3.33],
+[6.67,5,5,5,3.33,2.5,2.5],
+[5,5,5,3.33,2.5,2.5,2.5],
+[5,3.33,3.33,3.33,2.5,2.5,2],
+]
\ No newline at end of file
diff --git a/anypacking/quant_module.py b/anypacking/quant_module.py
index eaec02d..3252e0a 100644
--- a/anypacking/quant_module.py
+++ b/anypacking/quant_module.py
@@ -5,40 +5,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+# load dsp packing factors
+from .dsp_packing import *
+
 gaussian_steps = {1: 1.596, 2: 0.996, 3: 0.586, 4: 0.336, 5: 0.190, 6: 0.106, 7: 0.059, 8: 0.032}
 hwgq_steps = {1: 0.799, 2: 0.538, 3: 0.3217, 4: 0.185, 5: 0.104, 6: 0.058, 7: 0.033, 8: 0.019}
-
-dsp_factors_k11=[
-[12,8,8,6,6,4,4],
-[10,8,6,6,4,4,4],
-[8,6,6,4,4,4,3],
-[6,6,4,4,4,4,2],
-[6,4,4,4,2,2,2],
-[4,4,4,4,2,2,2],
-[4,4,3,2,2,2,2],
-]
-
-dsp_factors_k33=[
-[18,15,12,7.5,7.5,6,6],
-[15,12,7.5,6,6,6,3],
-[12,7.5,6,6,6,6,3],
-[9,6,6,6,6,3,3],
-[7.5,6,6,4.5,3,3,3],
-[6,6,4.5,3,3,3,2.25],
-[6,3,3,3,3,3,2],
-]
-
-dsp_factors_k55=[
-[20,15,10,7.5,7.5,5,5],
-[12.5,10,6.67,5,5,5,3.33],
-[10,7.5,5,5,5,5,3.33],
-[7.5,6.67,5,5,5,3.33,3.33],
-[6.67,5,5,5,3.33,2.5,2.5],
-[5,5,5,3.33,2.5,2.5,2.5],
-[5,3.33,3.33,3.33,2.5,2.5,2],
-]
-
 class _gauss_quantize_sym(torch.autograd.Function):
 
     @staticmethod
@@ -394,7 +366,7 @@ class MixActivConv2d(nn.Module):
         out = self.mix_weight(out)
         return out
 
-    def complexity_loss_old(self):
+    def complexity_loss_trivial(self):
         sw = F.softmax(self.mix_activ.alpha_activ, dim=0)
         mix_abit = 0
         abits = self.mix_activ.bits
@@ -416,16 +388,16 @@ class MixActivConv2d(nn.Module):
         wbits = self.mix_weight.bits
 
         if self.kernel_size == 1:
-            dsp_factors = dsp_factors_k11
+            factors = factors_k11
         elif self.kernel_size == 3:
-            dsp_factors = dsp_factors_k33
+            factors = factors_k33
         elif self.kernel_size == 5:
-            dsp_factors = dsp_factors_k55
+            factors = factors_k55
         else:
             raise NotImplementedError
         for i in range(len(wbits)):
             for j in range(len(abits)):
-                mix_scale += sw[i] * sa[j] / dsp_factors[wbits[i]-2][abits[j]-2]
+                mix_scale += sw[i] * sa[j] / factors[wbits[i]-2][abits[j]-2]
         complexity = self.size_product.item() * 64 * mix_scale
         return complexity
@@ -493,21 +465,21 @@ class MixActivConv2d(nn.Module):
         bitw = self.param_size * wbits[best_weight]
 
         if self.kernel_size == 1:
-            dsp_factors = dsp_factors_k11
+            factors = factors_k11
         elif self.kernel_size == 3:
-            dsp_factors = dsp_factors_k33
+            factors = factors_k33
         elif self.kernel_size == 5:
-            dsp_factors = dsp_factors_k55
+            factors = factors_k55
         else:
             raise NotImplementedError
-        dsps = size_product / dsp_factors[wbits[best_weight]-2][abits[best_activ]-2]
+        dsps = size_product / factors[wbits[best_weight]-2][abits[best_activ]-2]
         mixbitops = size_product * mix_abit * mix_wbit
         mixbita = memory_size * mix_abit
         mixbitw = self.param_size * mix_wbit
         mixdsps = 0
         for i in range(len(wbits)):
             for j in range(len(abits)):
-                mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors[wbits[i]-2][abits[j]-2]
+                mixdsps += prob_weight[i] * prob_activ[j] / factors[wbits[i]-2][abits[j]-2]
         mixdsps *= size_product
         mixbram_weight = self.param_size * 1e3 * mix_wbit # kbit
         mixbram_cache = bram_sw * mix_abit # kbit
@@ -595,7 +567,7 @@ class MixActivLinear(nn.Module):
         wbits = self.mix_weight.bits
         for i in range(len(wbits)):
             for j in range(len(abits)):
-                mix_scale += sw[i] * sa[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
+                mix_scale += sw[i] * sa[j] / factors_k11[wbits[i]-2][abits[j]-2]
         complexity = self.size_product.item() * 64 * mix_scale
         return complexity
@@ -627,13 +599,13 @@ class MixActivLinear(nn.Module):
         bitops = size_product * abits[best_activ] * wbits[best_weight]
         bita = memory_size * abits[best_activ]
         bitw = self.param_size * wbits[best_weight]
-        dsps = size_product / dsp_factors_k11[wbits[best_weight]-2][abits[best_activ]-2]
+        dsps = size_product / factors_k11[wbits[best_weight]-2][abits[best_activ]-2]
         mixbitops = size_product * mix_abit * mix_wbit
         mixbita = memory_size * mix_abit
         mixbitw = self.param_size * mix_wbit
         mixdsps = 0
         for i in range(len(wbits)):
             for j in range(len(abits)):
-                mixdsps += prob_weight[i] * prob_activ[j] / dsp_factors_k11[wbits[i]-2][abits[j]-2]
+                mixdsps += prob_weight[i] * prob_activ[j] / factors_k11[wbits[i]-2][abits[j]-2]
         mixdsps *= size_product
         return best_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps
diff --git a/cifar/main_train.py b/cifar/main_train.py
index 7595f81..8bf7149 100644
--- a/cifar/main_train.py
+++ b/cifar/main_train.py
@@ -105,22 +105,36 @@ def train():
 
     print('Finished Training')
+    with open('results.csv', 'a') as f:
+        print("fixed,%s,%d/%d, , ,%s,%s,%.1f,%.1f, , , ,%d, ,%.3f, "%
+            (opt.name,epochs-1,epochs,opt.bitw,opt.bita,macc*100,(test_acc+test_best_acc)/2,
+            int(round(bops)), dsps), file=f)
+
     # torch.save(net.state_dict(), 'lenet_cifar10.pth')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-n', '--name', default='VGG_tiny_FixQ', help='result and weight file name')
     parser.add_argument('-w', '--weights', default=None, help='weights path')
-    parser.add_argument('-e', '--epochs', type=int, default=40)
+    parser.add_argument('-e', '--epochs', type=int, default=200)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--bypass', action='store_true', help='use bypass model')
     parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
     parser.add_argument('--lr', type=float, default=0.03)
+    parser.add_argument('--mixm', type=str)
     parser.add_argument('--bitw', type=str, default='')
     parser.add_argument('--bita', type=str, default='')
     opt = parser.parse_args()
+
+    if opt.mixm is not None:
+        wmix = torch.load('weights/%s.pt'%opt.mixm)
+        opt.bitw = wmix['extra']['bestw']
+        opt.bita = wmix['extra']['besta']
+        del wmix
+
     print(opt)
+
     wdir = 'weights' + os.sep # weights dir
     last = wdir + '%s_last.pt'%opt.name
diff --git a/cifar/models.py b/cifar/models.py
index 390cc99..661dbf6 100644
--- a/cifar/models.py
+++ b/cifar/models.py
@@ -125,22 +125,22 @@ class VGG_tiny_MixQ(nn.Module):
             nn.BatchNorm2d(64),
             conv_func(64, 64, **conv_kwargs, **qspace), # 1
-            self.pooling,
             nn.BatchNorm2d(64),
+            self.pooling,
 
             conv_func(64, 128, **conv_kwargs, **qspace), # 2
             nn.BatchNorm2d(128),
             conv_func(128, 128, **conv_kwargs, **qspace), # 3
-            self.pooling,
             nn.BatchNorm2d(128),
+            self.pooling,
 
             conv_func(128, 256, **conv_kwargs, **qspace), # 4
             nn.BatchNorm2d(256),
             conv_func(256, 256, **conv_kwargs, **qspace), # 5
-            self.pooling,
             nn.BatchNorm2d(256),
+            self.pooling,
 
             nn.Flatten(),
             qm.QuantActivLinear(256*4*4, num_classes, bias=True, wbit=8, abit=8)
@@ -156,7 +156,7 @@ class VGG_tiny_MixQ(nn.Module):
         best_arch = None
         for m in self.modules():
             if isinstance(m, self.conv_func):
-                layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps = m.fetch_best_arch(layer_idx)
+                layer_arch, bitops, bita, bitw, mixbitops, mixbita, mixbitw, dsps, mixdsps, mixbram_weight, mixbram_cache = m.fetch_best_arch(layer_idx)
                 if best_arch is None:
                     best_arch = layer_arch
                 else:
@@ -186,6 +186,17 @@ class VGG_tiny_MixQ(nn.Module):
         normalizer = size_product[0].item()
         loss /= normalizer
         return loss
+
+    def complexity_loss_trivial(self):
+        size_product = []
+        loss = 0
+        for m in self.modules():
+            if isinstance(m, self.conv_func):
+                loss += m.complexity_loss_trivial()
+                size_product += [m.size_product]
+        normalizer = size_product[0].item()
+        loss /= normalizer
+        return loss
 
 class VGG_tiny_FixQ(nn.Module):
     def __init__(self, num_classes=10, bitw = '444444', bita = '844444'):
@@ -211,22 +222,22 @@ class VGG_tiny_FixQ(nn.Module):
             nn.BatchNorm2d(64),
             conv_func(64, 64, **conv_kwargs, wbit=bitw[1], abit=bita[1]), # 1
-            nn.MaxPool2d(kernel_size=2, stride=2),
             nn.BatchNorm2d(64),
+            nn.MaxPool2d(kernel_size=2, stride=2),
 
             conv_func(64, 128, **conv_kwargs, wbit=bitw[2], abit=bita[2]), # 2
             nn.BatchNorm2d(128),
             conv_func(128, 128, **conv_kwargs, wbit=bitw[3], abit=bita[3]), # 3
-            nn.MaxPool2d(kernel_size=2, stride=2),
             nn.BatchNorm2d(128),
+            nn.MaxPool2d(kernel_size=2, stride=2),
 
             conv_func(128, 256, **conv_kwargs, wbit=bitw[4], abit=bita[4]), # 4
             nn.BatchNorm2d(256),
             conv_func(256, 256, **conv_kwargs, wbit=bitw[5], abit=bita[5]), # 5
-            nn.MaxPool2d(kernel_size=2, stride=2),
             nn.BatchNorm2d(256),
+            nn.MaxPool2d(kernel_size=2, stride=2),
 
             nn.Flatten(),
             qm.QuantActivLinear(256*4*4, num_classes, bias=True, wbit=8, abit=8)
@@ -245,7 +256,7 @@ class VGG_tiny_FixQ(nn.Module):
                 bitops = size_product * m.abit * m.wbit
                 bita = m.memory_size.item() * m.abit
                 bitw = m.param_size * m.wbit
-                dsps = size_product / qm.dsp_factors[m.wbit-2][m.abit-2]
+                dsps = size_product / qm.factors_k33[m.wbit-2][m.abit-2]
                 weight_shape = list(m.conv.weight.shape)
                 print('idx {} with shape {}, bitops: {:.3f}M * {} * {}, memory: {:.3f}K * {}, '
                       'param: {:.3f}M * {}, dsps: {:.3f}M'.format(layer_idx, weight_shape, size_product, m.abit,
diff --git a/cifar/search_train.py b/cifar/search_train.py
index a9ae1e8..fd3bea6 100644
--- a/cifar/search_train.py
+++ b/cifar/search_train.py
@@ -75,8 +75,9 @@ def train():
             correct = (predicted == labels).sum().item()
             loss = criterion(outputs, labels)
 
-            if opt.complexity_decay != 0:
-                loss_complexity = opt.complexity_decay * model.complexity_loss()
+            if opt.complexity_decay != 0 or opt.complexity_decay_trivial != 0:
+                loss_complexity = opt.complexity_decay * model.complexity_loss() + \
+                    opt.complexity_decay_trivial * model.complexity_loss_trivial()
                 loss += loss_complexity
 
             loss.backward()
@@ -94,6 +95,8 @@ def train():
                 bitops, bita, bitw, dsps))
             print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M'.format(
                 mixbitops, mixbita, mixbitw, mixdsps))
+            bestw_str = "".join([str(x+2) for x in best_arch["best_weight"]])
+            besta_str = "".join([str(x+2) for x in best_arch["best_activ"]])
             print(f'best_weight: {best_arch["best_weight"]}')
             print(f'best_activ: {best_arch["best_activ"]}')
@@ -115,7 +118,7 @@ def train():
                      model) is nn.parallel.DistributedDataParallel else model.state_dict(),
                  'optimizer': None if final_epoch else optimizer.state_dict(),
                  'arch_optimizer': None if final_epoch else arch_optimizer.state_dict(),
-                 'extra': {'time': time.ctime(), 'name': opt.name}}
+                 'extra': {'time': time.ctime(), 'name': opt.name, 'bestw': bestw_str, 'besta': besta_str}}
 
         # Save last checkpoint
         torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
@@ -124,6 +127,12 @@ def train():
 
     print('Finished Training')
+    with open('results.csv', 'a') as f:
+        print("mixed,%s,%d/%d, , , , ,%.1f,%.1f, ,%s,%s,%d,%d,%.3f,%.3f"%
+            (opt.name,epochs-1,epochs,macc*100,(test_acc+test_best_acc)/2,
+            bestw_str,besta_str,
+            int(round(bitops)), int(round(mixbitops)), dsps, mixdsps), file=f)
+
     # torch.save(net.state_dict(), 'lenet_cifar10.pth')
 
 if __name__ == '__main__':
@@ -135,6 +144,7 @@ if __name__ == '__main__':
     parser.add_argument('--name', default='', help='result and weight file name')
     parser.add_argument('--noshare', action='store_true', help='no share weight')
    parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
+    parser.add_argument('--complexity-decay-trivial', '--cdt', default=0, type=float, metavar='W', help='complexity decay w/o hardware-aware DSP packing (default: 0)')
     parser.add_argument('--lra', '--learning-rate-alpha', default=0.1, type=float, metavar='LR', help='initial alpha learning rate')
     opt = parser.parse_args()
diff --git a/dacsdc/main_train.py b/dacsdc/main_train.py
index 18dad32..f3c7ee1 100644
--- a/dacsdc/main_train.py
+++ b/dacsdc/main_train.py
@@ -244,6 +244,8 @@ def train():
         # Update scheduler
         scheduler.step()
 
+        train_iou = mloss[2]
+
         # Process epoch results
         final_epoch = epoch + 1 == epochs
         if not opt.notest or final_epoch:  # Calculate mAP
@@ -293,6 +295,11 @@ def train():
     print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
     dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
     torch.cuda.empty_cache()
+
+    with open('results.csv', 'a') as f:
+        print("fixed,%s,%d/%d, , ,%s,%s,%.1f,%.1f, , , ,%d, ,%.3f, "%
+            (opt.name,epochs-1,epochs,opt.bitw,opt.bita,train_iou*100,(test_iou+test_best_iou)*50,
+            int(round(bops)), dsps), file=f)
     return results
@@ -317,12 +324,19 @@
     parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
     parser.add_argument('--adam', action='store_true', help='use adam optimizer')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
+    parser.add_argument('--mixm', type=str)
     parser.add_argument('--bitw', type=str, default='')
     parser.add_argument('--bita', type=str, default='')
     parser.add_argument('--var', type=float, help='debug variable')
     parser.add_argument('--model', type=str, default='', help='use specific model')
     opt = parser.parse_args()
+
+    if opt.mixm is not None:
+        wmix = torch.load('weights/%s.pt'%opt.mixm)
+        opt.bitw = wmix['extra']['bestw']
+        opt.bita = wmix['extra']['besta']
+        del wmix
 
     last = wdir + 'last_%s.pt'%opt.name
     opt.weights = last if opt.resume else opt.weights
     print(opt)
diff --git a/dacsdc/mymodel.py b/dacsdc/mymodel.py
index b1544aa..f0e0b5e 100644
--- a/dacsdc/mymodel.py
+++ b/dacsdc/mymodel.py
@@ -152,6 +152,17 @@ def mixq_bram_loss(self):
     loss /= normalizer
     return loss
 
+def mixq_complexity_loss_trivial(self):
+    size_product = []
+    loss = 0
+    for m in self.modules():
+        if isinstance(m, self.conv_func):
+            loss += m.complexity_loss_trivial()
+            size_product += [m.size_product]
+    normalizer = size_product[0].item()
+    loss /= normalizer
+    return loss
+
 class UltraNet_ismart(nn.Module):
     def __init__(self):
         super(UltraNet_ismart, self).__init__()
@@ -367,7 +378,8 @@ class UltraNet_MixQ(nn.Module):
     fetch_best_arch = mixq_fetch_best_arch
 
     # def complexity_loss(self):
-    complexity_loss= mixq_complexity_loss
+    complexity_loss = mixq_complexity_loss
+    complexity_loss_trivial = mixq_complexity_loss_trivial
     bram_loss = mixq_bram_loss
 
 class UltraNet_FixQ(nn.Module):
@@ -685,7 +697,8 @@ class UltraNetBypass_MixQ(nn.Module):
     fetch_best_arch = mixq_fetch_best_arch
 
     # def complexity_loss(self):
-    complexity_loss= mixq_complexity_loss
+    complexity_loss = mixq_complexity_loss
+    complexity_loss_trivial = mixq_complexity_loss_trivial
     bram_loss = mixq_bram_loss
 
 class UltraNetBypass_FixQ(nn.Module):
@@ -876,7 +889,8 @@ class SkyNet_MixQ(nn.Module):
     fetch_best_arch = mixq_fetch_best_arch
 
     # def complexity_loss(self):
-    complexity_loss= mixq_complexity_loss
+    complexity_loss = mixq_complexity_loss
+    complexity_loss_trivial = mixq_complexity_loss_trivial
     bram_loss = mixq_bram_loss
 
 class SkyNetk5_MixQ(nn.Module):
diff --git a/dacsdc/pareto_train.py b/dacsdc/pareto_train.py
new file mode 100644
index 0000000..c8f5252
--- /dev/null
+++ b/dacsdc/pareto_train.py
@@ -0,0 +1,30 @@
+import os
+import argparse
+
+cds = {
+'cd':['3e-5', '6e-5', '1e-4', '2e-4', '3e-4'],
+'cdt':['1e-5', '2e-5', '3e-5', '6e-5', '1e-4'],
+}
+
+def search_train():
+    for cd in cds[opt.arg]:
+        name = '%d_%s_'%(opt.it, opt.arg)+cd.replace('-','').replace('.','')
+        os.system('python search_train.py --name %s --%s %s'%('f'+name, opt.arg, cd))
+
+def main_train():
+    for cd in cds[opt.arg]:
+        name = '%d_%s_'%(opt.it, opt.arg)+cd.replace('-','').replace('.','')
+        os.system('python main_train.py --name %s --mixm %s'%('x'+name, 'f'+name+'_last'))
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--search', action='store_true')
+parser.add_argument('--main', action='store_true')
+parser.add_argument('--it', type=int)
+parser.add_argument('--arg', type=str)
+opt = parser.parse_args()
+
+if opt.search:
+    search_train()
+
+if opt.main:
+    main_train()
diff --git a/dacsdc/search_train.py b/dacsdc/search_train.py
index 917019e..c964055 100644
--- a/dacsdc/search_train.py
+++ b/dacsdc/search_train.py
@@ -239,16 +239,14 @@ def train():
             # complexity penalty
             if opt.complexity_decay != 0:
-                if hasattr(model, 'module'):
-                    loss_complexity = opt.complexity_decay * model.module.complexity_loss()
-                else:
-                    loss_complexity = opt.complexity_decay * model.complexity_loss()
+                loss_complexity = opt.complexity_decay * model.complexity_loss()
                 loss += loss_complexity * 4.0
+
+            if opt.complexity_decay_trivial != 0:
+                loss_complexity_trivial = opt.complexity_decay_trivial * model.complexity_loss_trivial()
+                loss += loss_complexity_trivial * 4.0
 
             if opt.bram_decay != 0:
-                if hasattr(model, 'module'):
-                    loss_bram = opt.bram_decay * model.module.bram_loss()
-                else:
-                    loss_bram = opt.bram_decay * model.bram_loss()
+                loss_bram = opt.bram_decay * model.bram_loss()
                 loss += loss_bram * 4.0
@@ -278,13 +277,18 @@ def train():
                 bitops, bita, bitw, dsps))
             print('expected model with bitops: {:.3f}M, bita: {:.3f}K, bitw: {:.3f}M, dsps: {:.3f}M, bram_wa:({:.3f},{:.3f})K'.format(
                 mixbitops, mixbita, mixbitw, mixdsps, mixbram_weight, mixbram_cache))
-            print(f'best_weight: {best_arch["best_weight"]} ({"".join([str(x+2) for x in best_arch["best_weight"]])})')
-            print(f'best_activ: {best_arch["best_activ"]} ({"".join([str(x+2) for x in best_arch["best_activ"]])})')
-
+
+            bestw_str = "".join([str(x+2) for x in best_arch["best_weight"]])
+            besta_str = "".join([str(x+2) for x in best_arch["best_activ"]])
+            print(f'best_weight: {best_arch["best_weight"]}')
+            print(f'best_activ: {best_arch["best_activ"]}')
+
         # Update scheduler
         scheduler.step()
         arch_scheduler.step()
 
+        train_iou = mloss[2]
+
         # Process epoch results
         final_epoch = epoch + 1 == epochs
         if not opt.notest or final_epoch:  # Calculate mAP
@@ -314,7 +318,7 @@ def train():
                      'model': model.module.state_dict() if type(
                          model) is nn.parallel.DistributedDataParallel else model.state_dict(),
                      'optimizer': None if final_epoch else optimizer.state_dict(),
-                     'extra': {'time': time.ctime(), 'name': opt.name}}
+                     'extra': {'time': time.ctime(), 'name': opt.name, 'bestw': bestw_str, 'besta': besta_str}}
 
             # Save last checkpoint
             torch.save(chkpt, wdir + '%s_last.pt'%opt.name)
@@ -333,6 +337,12 @@ def train():
     print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
     dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
     torch.cuda.empty_cache()
+
+    with open('results.csv', 'a') as f:
+        print("mixed,%s,%d/%d, , , , ,%.1f,%.1f, ,%s,%s,%d,%d,%.3f,%.3f"%
+            (opt.name,epochs-1,epochs,train_iou*100,(test_iou+test_best_iou)*50,
+            bestw_str,besta_str,
+            int(round(bitops)), int(round(mixbitops)), dsps, mixdsps), file=f)
     return results
@@ -359,6 +369,7 @@
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
     parser.add_argument('--var', type=float, help='debug variable')
     parser.add_argument('--complexity-decay', '--cd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
+    parser.add_argument('--complexity-decay-trivial', '--cdt', default=0, type=float, metavar='W', help='complexity decay w/o hardware-aware DSP packing (default: 0)')
     parser.add_argument('--bram-decay', '--bd', default=0, type=float, metavar='W', help='complexity decay (default: 0)')
     parser.add_argument('--lra', '--learning-rate-alpha', default=0.01, type=float, metavar='LR', help='initial alpha learning rate')
     parser.add_argument('--no-share', action='store_true', help='no share weight quantization')