Merge pull request #32 from xiaoda99/master
Fix ineffective no_decay bug when using BERTAdam
This commit is contained in:
commit
061eeca84a
|
@@ -503,8 +503,8 @@ def main():
|
|||
param_optimizer = list(model.named_parameters())
|
||||
no_decay = ['bias', 'gamma', 'beta']
|
||||
optimizer_grouped_parameters = [
|
||||
- {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
|
||||
- {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
|
||||
+ {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
|
||||
+ {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
|
||||
]
|
||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
|
|
Loading…
Reference in New Issue