diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 8be7ca05c3..5660f38360 100644 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -53,6 +53,16 @@ class Seq2SeqTrainingArguments(TrainingArguments): default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} ) adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"}) + encoder_layerdrop: Optional[float] = field( + default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."} + ) + decoder_layerdrop: Optional[float] = field( + default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."} + ) + dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."}) + attention_dropout: Optional[float] = field( + default=None, metadata={"help": "Attention dropout probability. Goes into model.config."} + ) @dataclass @@ -179,6 +189,13 @@ def main(): model_args.config_name if model_args.config_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, ) + + extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") + for p in extra_model_params: + if getattr(training_args, p, None): + assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute" + setattr(config, p, getattr(training_args, p)) + tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 3d143a09bb..0f585eb262 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -6,6 +6,7 @@ from torch import nn from torch.utils.data import DistributedSampler, RandomSampler from transformers import Trainer +from transformers.configuration_fsmt import FSMTConfig from transformers.file_utils import is_torch_tpu_available from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup from transformers.trainer import get_tpu_sampler @@ -26,8 +27,7 @@ class Seq2SeqTrainer(Trainer): self.config = config self.data_args = data_args self.max_gen_length = data_args.val_max_target_length - self.pad_token_id = self.config.pad_token_id - self.vocab_size = self.config.vocab_size + self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size def create_optimizer_and_scheduler(self, num_training_steps: int): """ @@ -87,18 +87,18 @@ class Seq2SeqTrainer(Trainer): labels = inputs.pop("labels") outputs = model(**inputs, use_cache=False) logits = outputs[0] - return self._compute_loss(logits, labels, ignore_index=self.pad_token_id) + return self._compute_loss(logits, labels) - def _compute_loss(self, logits, labels, ignore_index): + def _compute_loss(self, logits, labels): if self.args.label_smoothing == 0: # Same behavior as modeling_bart.py - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) assert logits.shape[-1] == self.vocab_size loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) else: lprobs = torch.nn.functional.log_softmax(logits, dim=-1) loss, nll_loss = label_smoothed_nll_loss( - lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index + lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id ) return loss @@ -137,14 +137,12 @@ class Seq2SeqTrainer(Trainer): max_length=self.max_gen_length, ) # in case the batch is shorter than max length, the output should be padded - generated_tokens = self._pad_tensors_to_max_len( - generated_tokens, self.max_gen_length, self.pad_token_id - ) + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length) labels_out = inputs.get("labels") # Call forward again to get loss # TODO: avoidable? outputs = model(**inputs, use_cache=False) - loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id) + loss = self._compute_loss(outputs[1], labels_out) loss = loss.mean().item() if self.args.prediction_loss_only: return (loss, None, None) @@ -152,11 +150,11 @@ class Seq2SeqTrainer(Trainer): logits = generated_tokens if self.args.predict_with_generate else outputs[1] labels_out = labels_out.detach() - labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id) + labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length) return (loss, logits.detach(), labels) - def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id): - padded_tensor = pad_token_id * torch.ones( + def _pad_tensors_to_max_len(self, tensor, max_length): + padded_tensor = self.config.pad_token_id * torch.ones( (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device ) padded_tensor[:, : tensor.shape[-1]] = tensor diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 39d44d774a..517e76b232 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -26,7 +26,7 @@ def test_finetune_trainer(): def test_finetune_trainer_slow(): # TODO(SS): This will fail on devices with more than 1 GPU. # There is a missing call to __init__process_group somewhere - output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3) + output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3) # Check metrics logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 18928e3dd3..d19cf63549 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -269,7 +269,11 @@ class Seq2SeqDataCollator: ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined." self.data_args = data_args self.tpu_num_cores = tpu_num_cores - self.add_prefix_space = isinstance(tokenizer, BartTokenizer) + self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)} + if data_args.src_lang is not None: + self.dataset_kwargs["src_lang"] = data_args.src_lang + if data_args.tgt_lang is not None: + self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang def __call__(self, batch) -> Dict[str, torch.Tensor]: if hasattr(self.tokenizer, "prepare_seq2seq_batch"): @@ -310,14 +314,12 @@ class Seq2SeqDataCollator: def _encode(self, batch) -> Dict[str, torch.Tensor]: batch_encoding = self.tokenizer.prepare_seq2seq_batch( [x["src_texts"] for x in batch], - src_lang=self.data_args.src_lang, tgt_texts=[x["tgt_texts"] for x in batch], - tgt_lang=self.data_args.tgt_lang, max_length=self.data_args.max_source_length, max_target_length=self.data_args.max_target_length, padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack return_tensors="pt", - add_prefix_space=self.add_prefix_space, + **self.dataset_kwargs, ) return batch_encoding.data