[s2s] add config params like Dropout in Seq2SeqTrainingArguments (#7532)
This commit is contained in:
parent
9bdce3a4f9
commit
99cb924bfb
|
@ -53,6 +53,16 @@ class Seq2SeqTrainingArguments(TrainingArguments):
|
||||||
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||||
)
|
)
|
||||||
adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})
|
adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})
|
||||||
|
encoder_layerdrop: Optional[float] = field(
|
||||||
|
default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."}
|
||||||
|
)
|
||||||
|
decoder_layerdrop: Optional[float] = field(
|
||||||
|
default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."}
|
||||||
|
)
|
||||||
|
dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."})
|
||||||
|
attention_dropout: Optional[float] = field(
|
||||||
|
default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -179,6 +189,13 @@ def main():
|
||||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||||
|
for p in extra_model_params:
|
||||||
|
if getattr(training_args, p, None):
|
||||||
|
assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute"
|
||||||
|
setattr(config, p, getattr(training_args, p))
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
|
|
|
@ -6,6 +6,7 @@ from torch import nn
|
||||||
from torch.utils.data import DistributedSampler, RandomSampler
|
from torch.utils.data import DistributedSampler, RandomSampler
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
from transformers.configuration_fsmt import FSMTConfig
|
||||||
from transformers.file_utils import is_torch_tpu_available
|
from transformers.file_utils import is_torch_tpu_available
|
||||||
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
||||||
from transformers.trainer import get_tpu_sampler
|
from transformers.trainer import get_tpu_sampler
|
||||||
|
@ -26,8 +27,7 @@ class Seq2SeqTrainer(Trainer):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.data_args = data_args
|
self.data_args = data_args
|
||||||
self.max_gen_length = data_args.val_max_target_length
|
self.max_gen_length = data_args.val_max_target_length
|
||||||
self.pad_token_id = self.config.pad_token_id
|
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
||||||
self.vocab_size = self.config.vocab_size
|
|
||||||
|
|
||||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||||
"""
|
"""
|
||||||
|
@ -87,18 +87,18 @@ class Seq2SeqTrainer(Trainer):
|
||||||
labels = inputs.pop("labels")
|
labels = inputs.pop("labels")
|
||||||
outputs = model(**inputs, use_cache=False)
|
outputs = model(**inputs, use_cache=False)
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)
|
return self._compute_loss(logits, labels)
|
||||||
|
|
||||||
def _compute_loss(self, logits, labels, ignore_index):
|
def _compute_loss(self, logits, labels):
|
||||||
if self.args.label_smoothing == 0:
|
if self.args.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py
|
# Same behavior as modeling_bart.py
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||||
assert logits.shape[-1] == self.vocab_size
|
assert logits.shape[-1] == self.vocab_size
|
||||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
else:
|
else:
|
||||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
loss, nll_loss = label_smoothed_nll_loss(
|
loss, nll_loss = label_smoothed_nll_loss(
|
||||||
lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
|
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
|
||||||
)
|
)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@ -137,14 +137,12 @@ class Seq2SeqTrainer(Trainer):
|
||||||
max_length=self.max_gen_length,
|
max_length=self.max_gen_length,
|
||||||
)
|
)
|
||||||
# in case the batch is shorter than max length, the output should be padded
|
# in case the batch is shorter than max length, the output should be padded
|
||||||
generated_tokens = self._pad_tensors_to_max_len(
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
|
||||||
generated_tokens, self.max_gen_length, self.pad_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
labels_out = inputs.get("labels")
|
labels_out = inputs.get("labels")
|
||||||
# Call forward again to get loss # TODO: avoidable?
|
# Call forward again to get loss # TODO: avoidable?
|
||||||
outputs = model(**inputs, use_cache=False)
|
outputs = model(**inputs, use_cache=False)
|
||||||
loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
|
loss = self._compute_loss(outputs[1], labels_out)
|
||||||
loss = loss.mean().item()
|
loss = loss.mean().item()
|
||||||
if self.args.prediction_loss_only:
|
if self.args.prediction_loss_only:
|
||||||
return (loss, None, None)
|
return (loss, None, None)
|
||||||
|
@ -152,11 +150,11 @@ class Seq2SeqTrainer(Trainer):
|
||||||
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
||||||
|
|
||||||
labels_out = labels_out.detach()
|
labels_out = labels_out.detach()
|
||||||
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
|
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
|
||||||
return (loss, logits.detach(), labels)
|
return (loss, logits.detach(), labels)
|
||||||
|
|
||||||
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
def _pad_tensors_to_max_len(self, tensor, max_length):
|
||||||
padded_tensor = pad_token_id * torch.ones(
|
padded_tensor = self.config.pad_token_id * torch.ones(
|
||||||
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
||||||
)
|
)
|
||||||
padded_tensor[:, : tensor.shape[-1]] = tensor
|
padded_tensor[:, : tensor.shape[-1]] = tensor
|
||||||
|
|
|
@ -26,7 +26,7 @@ def test_finetune_trainer():
|
||||||
def test_finetune_trainer_slow():
|
def test_finetune_trainer_slow():
|
||||||
# TODO(SS): This will fail on devices with more than 1 GPU.
|
# TODO(SS): This will fail on devices with more than 1 GPU.
|
||||||
# There is a missing call to __init__process_group somewhere
|
# There is a missing call to __init__process_group somewhere
|
||||||
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||||
|
|
||||||
# Check metrics
|
# Check metrics
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
|
|
|
@ -269,7 +269,11 @@ class Seq2SeqDataCollator:
|
||||||
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
||||||
self.data_args = data_args
|
self.data_args = data_args
|
||||||
self.tpu_num_cores = tpu_num_cores
|
self.tpu_num_cores = tpu_num_cores
|
||||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
|
||||||
|
if data_args.src_lang is not None:
|
||||||
|
self.dataset_kwargs["src_lang"] = data_args.src_lang
|
||||||
|
if data_args.tgt_lang is not None:
|
||||||
|
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
|
||||||
|
|
||||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||||
|
@ -310,14 +314,12 @@ class Seq2SeqDataCollator:
|
||||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||||
[x["src_texts"] for x in batch],
|
[x["src_texts"] for x in batch],
|
||||||
src_lang=self.data_args.src_lang,
|
|
||||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||||
tgt_lang=self.data_args.tgt_lang,
|
|
||||||
max_length=self.data_args.max_source_length,
|
max_length=self.data_args.max_source_length,
|
||||||
max_target_length=self.data_args.max_target_length,
|
max_target_length=self.data_args.max_target_length,
|
||||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
add_prefix_space=self.add_prefix_space,
|
**self.dataset_kwargs,
|
||||||
)
|
)
|
||||||
return batch_encoding.data
|
return batch_encoding.data
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue