[s2s] Adafactor support for builtin trainer (#7522)

Sam Shleifer 2020-10-01 17:27:45 -04:00 committed by GitHub
parent d3a9601a11
commit de4d7b004a
3 changed files with 40 additions and 0 deletions


@@ -52,6 +52,7 @@ class Seq2SeqTrainingArguments(TrainingArguments):
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})

@dataclass
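The new field plugs into the existing argument-parsing flow, so the optimizer can be switched on from the command line (--adafactor) or when building the arguments in code. A minimal usage sketch; the import path is an assumption (it is not shown in this diff), and the hyperparameter values are illustrative only:

# Sketch only: module path and values are assumptions, not part of this commit.
from seq2seq_training_args import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="output_dir",
    learning_rate=1e-3,  # Adafactor is often run with a larger external LR than AdamW
    adafactor=True,      # the flag added by this commit
)
assert training_args.adafactor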


@@ -7,6 +7,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
from transformers.trainer import get_tpu_sampler
@@ -28,6 +29,43 @@ class Seq2SeqTrainer(Trainer):
        self.pad_token_id = self.config.pad_token_id
        self.vocab_size = self.config.vocab_size

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple
        to the Trainer's init through :obj:`optimizers`, or subclass and override this method.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.adafactor:
                self.optimizer = Adafactor(
                    optimizer_grouped_parameters,
                    lr=self.args.learning_rate,
                    scale_parameter=False,
                    relative_step=False,
                )
            else:
                self.optimizer = AdamW(
                    optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
                )

        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
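For reference, the optimizer branch above can be exercised outside the Trainer. This is a standalone sketch (the toy model and hyperparameter values are illustrative, not from the commit) showing that with scale_parameter=False and relative_step=False, Adafactor consumes the externally supplied learning rate just like the AdamW path:

# Standalone sketch of the branch above; model and values are illustrative only.
import torch
from transformers.optimization import Adafactor, AdamW

model = torch.nn.Linear(16, 16)   # stand-in for the seq2seq model
use_adafactor = True              # mirrors self.args.adafactor

if use_adafactor:
    # scale_parameter=False and relative_step=False disable Adafactor's internal
    # LR schedule, so it uses the learning rate passed in, matching the code above.
    optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False, relative_step=False)
else:
    optimizer = AdamW(model.parameters(), lr=1e-3, eps=1e-8)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

The weight-decay parameter grouping is shared by both branches, so the only behavioral difference introduced by the flag is the choice of optimizer itself.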


@@ -91,6 +91,7 @@ def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs
        "0.1",
        # "--eval_beams",
        # "2",
        "--adafactor",
        "--task",
        "translation",
        "--tgt_lang",