[s2s] Adafactor support for builtin trainer (#7522)
This commit is contained in:
parent
d3a9601a11
commit
de4d7b004a
|
@ -52,6 +52,7 @@ class Seq2SeqTrainingArguments(TrainingArguments):
|
|||
predict_with_generate: bool = field(
|
||||
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||
)
|
||||
adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -7,6 +7,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
|
|||
|
||||
from transformers import Trainer
|
||||
from transformers.file_utils import is_torch_tpu_available
|
||||
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
||||
from transformers.trainer import get_tpu_sampler
|
||||
|
||||
|
||||
|
@ -28,6 +29,43 @@ class Seq2SeqTrainer(Trainer):
|
|||
self.pad_token_id = self.config.pad_token_id
|
||||
self.vocab_size = self.config.vocab_size
|
||||
|
||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||
"""
|
||||
Setup the optimizer and the learning rate scheduler.
|
||||
|
||||
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
||||
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
|
||||
"""
|
||||
if self.optimizer is None:
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
if self.args.adafactor:
|
||||
self.optimizer = Adafactor(
|
||||
optimizer_grouped_parameters,
|
||||
lr=self.args.learning_rate,
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
)
|
||||
|
||||
else:
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
|
||||
)
|
||||
|
||||
if self.lr_scheduler is None:
|
||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
||||
)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
return None
|
||||
|
|
|
@ -91,6 +91,7 @@ def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs
|
|||
"0.1",
|
||||
# "--eval_beams",
|
||||
# "2",
|
||||
"--adafactor",
|
||||
"--task",
|
||||
"translation",
|
||||
"--tgt_lang",
|
||||
|
|
Loading…
Reference in New Issue