Prepare optimizer only when args.do_train is True
parent 3ae8c8be1e
commit 74dbba64bc
@@ -534,6 +534,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
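For context, the block being moved under the guard is the weight-decay grouping these example scripts share. A minimal sketch of the guarded optimizer preparation, assuming the usual BertAdam setup from pytorch_pretrained_bert (args, model, and num_train_steps come from the surrounding script; the weight_decay_rate values are the ones these scripts typically use, not something shown in this diff):

from pytorch_pretrained_bert.optimization import BertAdam

if args.do_train:
    param_optimizer = list(model.named_parameters())
    # Exclude biases and LayerNorm parameters from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0},
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

The same one-line guard is applied at each call site below; the wrapped lines are re-indented under the new `if`, which this view renders as unchanged context.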
@@ -763,6 +763,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
@@ -183,6 +183,7 @@ def main():
         eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
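This hunk sits right after the evaluation dataloader, which points at the motivation for the commit: evaluation-only runs previously executed the optimizer preparation even though its training-only inputs were never computed. A hypothetical reduction of that failure mode, with variable names assumed from the typical layout of these scripts rather than shown in this diff:

# Training-only values are computed lazily in these scripts:
train_examples = None
num_train_steps = None
if args.do_train:
    train_examples = processor.get_train_examples(args.data_dir)
    num_train_steps = int(
        len(train_examples) / args.train_batch_size) * args.num_train_epochs

# Before this commit the optimizer was built unconditionally, so an
# eval-only run (no --do_train) fed t_total=None into the warmup
# schedule. After the commit the whole block runs only under
# `if args.do_train:` and eval-only runs skip it entirely.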
@@ -922,6 +922,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())

         # hack to remove pooler, which is not used
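The trailing context line, "hack to remove pooler, which is not used", refers to filtering the pooler's parameters out of param_optimizer before the weight-decay grouping, since the task head in this script never touches the pooler output. A sketch of that filter as it typically appears in these scripts (an assumption; the expression itself lies outside this hunk):

# Drop the pooler's parameters from the optimizer: this head never
# uses the pooler output, so there is no point training it.
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]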
@@ -385,6 +385,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())

         # hack to remove pooler, which is not used