Prepare optimizer only when args.do_train is True

MottoX 2019-05-02 19:09:29 +08:00
parent 3ae8c8be1e
commit 74dbba64bc
5 changed files with 130 additions and 125 deletions


@@ -534,6 +534,7 @@ def main():
 model = torch.nn.DataParallel(model)
 # Prepare optimizer
+if args.do_train:
 param_optimizer = list(model.named_parameters())
 no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
 optimizer_grouped_parameters = [
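Only the first line of the change is visible in the collapsed hunk above; the rest of the optimizer-preparation block is re-indented under the new guard, which accounts for the roughly equal addition and deletion counts across the five files. A minimal sketch of the resulting block follows. The optimizer construction (BertAdam, args.learning_rate, args.warmup_proportion, num_train_optimization_steps) follows the usual pattern of these example scripts and is an assumption, not part of the visible diff.

# Sketch of the guarded optimizer preparation after this commit.
# The optimizer construction below is assumed from the surrounding
# scripts; only the `if args.do_train:` line is visible in the hunk.
if args.do_train:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)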


@@ -763,6 +763,7 @@ def main():
 model = torch.nn.DataParallel(model)
 # Prepare optimizer
+if args.do_train:
 param_optimizer = list(model.named_parameters())
 no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
 optimizer_grouped_parameters = [


@@ -183,6 +183,7 @@ def main():
 eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
 # Prepare optimizer
+if args.do_train:
 param_optimizer = list(model.named_parameters())
 no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
 optimizer_grouped_parameters = [


@@ -922,6 +922,7 @@ def main():
 model = torch.nn.DataParallel(model)
 # Prepare optimizer
+if args.do_train:
 param_optimizer = list(model.named_parameters())
 # hack to remove pooler, which is not used
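The SQuAD-style scripts differ only in that the pooler parameters are dropped before grouping, since the pooler is not used for span prediction. The filtering expression in the sketch below is an assumption for illustration; only the comment "hack to remove pooler, which is not used" appears in the hunk.

# Sketch of the SQuAD variant of the same change. The filtering line
# is assumed; the visible hunk only shows the comment above it.
if args.do_train:
    param_optimizer = list(model.named_parameters())
    # hack to remove pooler, which is not used
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]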


@@ -385,6 +385,7 @@ def main():
 model = torch.nn.DataParallel(model)
 # Prepare optimizer
+if args.do_train:
 param_optimizer = list(model.named_parameters())
 # hack to remove pooler, which is not used