mirror of https://github.com/open-mmlab/mmpose
49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
|
|
|
|
def change_model(args):
    """Rewrite a checkpoint so only renamed weights and metadata remain.

    The checkpoint at ``args.dis_path`` is loaded onto the CPU, the keys of
    its ``'state_dict'`` entry are filtered and renamed, every top-level
    entry other than ``'meta'`` and ``'state_dict'`` is dropped, and the
    result is written to ``args.output_path``.

    Key handling when ``args.two_dis`` is false:
        * ``student.*`` -> the leading ``student.`` is stripped.
        * every other key is discarded.

    Key handling when ``args.two_dis`` is true:
        * ``teacher.backbone*`` / ``teacher.neck*`` / ``student.head*`` ->
          the leading 8-character ``teacher.`` / ``student.`` is stripped.
        * ``distill_losses.loss_mgd.down*`` -> the leading
          ``distill_losses.loss_mgd.`` is replaced with ``head.``.
        * every other key is discarded.

    Args:
        args: Object exposing ``dis_path``, ``output_path`` and ``two_dis``
            attributes (the script builds it with ``argparse``).
    """
    checkpoint = torch.load(args.dis_path, map_location='cpu')
    two_stage = args.two_dis
    source_weights = checkpoint['state_dict']

    def exported_key(key):
        """Return the output name for ``key``, or ``None`` to discard it."""
        if not two_stage:
            return key[8:] if key.startswith('student.') else None
        # 'teacher.' and 'student.' are both 8 characters long, so a single
        # slice strips either prefix.
        if key.startswith(
                ('teacher.backbone', 'teacher.neck', 'student.head')):
            return key[8:]
        # 24 == len('distill_losses.loss_mgd.'); the remainder is re-rooted
        # under 'head.'.
        if key.startswith('distill_losses.loss_mgd.down'):
            return 'head.' + key[24:]
        return None

    candidates = ((exported_key(key), weight)
                  for key, weight in source_weights.items())
    checkpoint['state_dict'] = OrderedDict(
        (key, weight) for key, weight in candidates if key is not None)

    # Strip everything except the metadata and the converted weights.
    retained = ('meta', 'state_dict')
    for stale in [entry for entry in checkpoint if entry not in retained]:
        del checkpoint[stale]

    torch.save(checkpoint, args.output_path)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: two positional paths (input checkpoint and
    # output file) plus an optional --two_dis switch that selects the
    # alternative key-renaming rules inside change_model.
    cli = argparse.ArgumentParser(description='Transfer CKPT')
    cli.add_argument('dis_path', help='dis_model path')
    cli.add_argument('output_path', help='output path')
    cli.add_argument(
        '--two_dis', default=False, action='store_true', help='if two dis')
    change_model(cli.parse_args())
|