[Trainer] Add nan/inf logging filter (#13619)

* finish

* add test

* push

* remove unnecessary code

* up

* correct test

* Update src/transformers/training_args.py
Patrick von Platen 2021-09-17 16:21:59 +02:00 committed by GitHub
parent eae7a96b7d
commit 1f9dcfc1ef
3 changed files with 46 additions and 2 deletions

src/transformers/trainer.py

@@ -1297,9 +1297,16 @@ class Trainer:
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
-                        tr_loss += self.training_step(model, inputs)
+                        tr_loss_step = self.training_step(model, inputs)
                 else:
-                    tr_loss += self.training_step(model, inputs)
+                    tr_loss_step = self.training_step(model, inputs)
+
+                if args.logging_nan_inf_filter and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
+                    # if loss is nan or inf simply add the average of previous logged losses
+                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+                else:
+                    tr_loss += tr_loss_step
+
                 self.current_flos += float(self.floating_point_ops(inputs))

                 # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps

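To make the arithmetic concrete, here is a minimal standalone sketch (not the Trainer code itself) of the filtering rule above: a nan/inf step contributes roughly the average of the losses accumulated since the last log, and the "1 +" in the denominator guards against division by zero on the first step of a logging window:

import math

def accumulate_loss(tr_loss, tr_loss_step, steps_since_last_log, filter_nan_inf=True):
    # Mirrors the hunk above: replace a nan/inf step loss with the window average.
    if filter_nan_inf and (math.isnan(tr_loss_step) or math.isinf(tr_loss_step)):
        # On the first step after a log, tr_loss == 0.0, so nothing is added.
        return tr_loss + tr_loss / (1 + steps_since_last_log)
    return tr_loss + tr_loss_step

# With three good steps accumulated (tr_loss = 6.0, average 2.0), a nan step
# adds 6.0 / (1 + 3) = 1.5 instead of poisoning the running sum:
print(accumulate_loss(6.0, float("nan"), steps_since_last_log=3))  # 7.5, not nan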
src/transformers/training_args.py

@@ -173,6 +173,16 @@ class TrainingArguments:
             Whether to log and evaluate the first :obj:`global_step` or not.
         logging_steps (:obj:`int`, `optional`, defaults to 500):
             Number of update steps between two logs if :obj:`logging_strategy="steps"`.
+        logging_nan_inf_filter (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether to filter :obj:`nan` and :obj:`inf` losses for logging. If set to :obj:`True`, the loss of every
+            step that is :obj:`nan` or :obj:`inf` is filtered and the average loss of the current logging window is
+            taken instead.
+
+            .. note::
+
+                :obj:`logging_nan_inf_filter` only influences the logging of loss values; it does not change how the
+                gradient is computed or applied to the model.
+
         save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
             The checkpoint save strategy to adopt during training. Possible values are:
@@ -468,6 +478,7 @@ class TrainingArguments:
     )
     logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
+    logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
     save_strategy: IntervalStrategy = field(
         default="steps",
         metadata={"help": "The checkpoint save strategy to use."},

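A minimal usage sketch of the new flag, assuming an otherwise standard Trainer setup (`model` and `train_dataset` are placeholders for your own objects); since the filter defaults to True, it only needs to be passed explicitly to turn it off:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    logging_steps=5,
    logging_nan_inf_filter=False,  # show raw nan/inf losses in the logs
)
# model / train_dataset: placeholders for your own model and dataset
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
print(trainer.state.log_history)  # one entry per logging step, plus a final summary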
tests/test_trainer.py

@@ -15,6 +15,7 @@
 import dataclasses
 import gc
+import math
 import os
 import random
 import re
@@ -528,6 +529,31 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, 10)

+    def test_logging_inf_nan_filter(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        # Trainer without inf/nan filter
+        args = TrainingArguments("./test", learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=False)
+        trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+        trainer.train()
+        log_history_no_filter = trainer.state.log_history
+
+        # Trainer with inf/nan filter
+        args = TrainingArguments("./test", learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=True)
+        trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+        trainer.train()
+        log_history_filter = trainer.state.log_history
+
+        def is_any_loss_nan_or_inf(log_history):
+            losses = [l["loss"] for l in log_history[:-1]]
+            return any(math.isnan(x) for x in losses) or any(math.isinf(x) for x in losses)
+
+        self.assertTrue(is_any_loss_nan_or_inf(log_history_no_filter))
+        self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
+
     def test_train_and_eval_dataloaders(self):
         n_gpu = max(1, torch.cuda.device_count())
         trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
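For reference, the test slices log_history[:-1] because the last entry of trainer.state.log_history is the end-of-training summary rather than a per-step log. An illustrative sketch of the structure (the concrete values, and the keys other than "loss", are made up for illustration):

# Illustrative shape of trainer.state.log_history; values and extra keys are assumptions
log_history = [
    {"loss": 2.31, "learning_rate": 1e9, "epoch": 0.08, "step": 5},
    {"loss": float("inf"), "learning_rate": 1e9, "epoch": 0.16, "step": 10},
    # ... one dict per logging step ...
    {"train_runtime": 1.2, "epoch": 1.0, "step": 64},  # final summary: no "loss" key
]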