Avoid looping when data exhausted (#14413)

* stop training when a finite IterableDataset is exhausted

When using an iterable dataset, num_epochs is set to
sys.maxsize to make sure all data is consumed.
Likewise, we want to set max_steps high enough,
but still stop once all data is consumed (see the
sketch after the change list below).

(cherry picked from commit 6f0e1d6363153da9051e93acffe1cbab3a3f3b12)

* fix typo flase -> false

* add test for stopping training on exhausted finite iterable dataset

* remove redundant gradient_accumulation_steps

* run make style

* reformat training_args docstring
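
As a rough sketch of the situation described above (illustrative only, not part of the commit; the TenSamples class is invented for the example): a finite IterableDataset is drained on its first pass, so a later pass through the DataLoader yields nothing, and without an explicit check the training loop would keep starting empty epochs until max_steps is reached.

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class TenSamples(IterableDataset):
        # A hypothetical finite stream: ten samples, then nothing.
        def __init__(self):
            self.remaining = list(range(10))

        def __iter__(self):
            while self.remaining:
                yield torch.tensor([float(self.remaining.pop(0))])

    loader = DataLoader(TenSamples(), batch_size=1)
    print(sum(1 for _ in loader))  # 10 -- the first pass consumes every sample
    print(sum(1 for _ in loader))  # 0  -- the stream is exhausted, nothing left to train on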

Author: Valentin (committed via GitHub)
Date: 2021-11-16 22:50:04 +01:00
Commit: a33168aa78 (parent 3e8d17e66d)
3 changed files with 41 additions and 2 deletions

src/transformers/trainer.py

@@ -1287,6 +1287,7 @@ class Trainer:
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

+            step = -1
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
@@ -1386,6 +1387,13 @@ class Trainer:
                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
+            if step < 0:
+                logger.warning(
+                    f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
+                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+                    f" num_steps ({max_steps}) higher than the number of available samples."
+                )
+                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
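
The hunk above relies on a simple sentinel: step is initialised to -1 before the inner loop, so if the epoch iterator yields nothing at all, step is still negative once the loop ends, a warning is logged, and should_training_stop ends training instead of spinning through further empty epochs. A minimal sketch of the sentinel pattern (illustrative, not the Trainer code itself):

    def iterator_was_empty(epoch_iterator):
        step = -1
        for step, _inputs in enumerate(epoch_iterator):
            pass  # a real training step would run here
        return step < 0  # True only if the loop body never ran

    print(iterator_was_empty(iter([])))           # True  -> stop training
    print(iterator_was_empty(iter(["batch_0"])))  # False -> keep training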

src/transformers/training_args.py

@@ -141,7 +141,8 @@ class TrainingArguments:
            the last epoch before stopping training).
        max_steps (:obj:`int`, `optional`, defaults to -1):
            If set to a positive number, the total number of training steps to perform. Overrides
-            :obj:`num_train_epochs`.
+            :obj:`num_train_epochs`. In case of using a finite iterable dataset, the training may stop before
+            reaching the set number of steps when all data is exhausted.
        lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
            The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
            values.
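
In practice the documented behaviour means max_steps is an upper bound rather than a guarantee. A hedged usage sketch (argument values are invented for illustration):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        max_steps=1_000,  # overrides num_train_epochs
        per_device_train_batch_size=8,
    )
    # With a finite iterable dataset that can only supply, say, 600 batches,
    # training now stops after 600 steps and logs a warning instead of cycling
    # through empty epochs until step 1_000 is reached.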

tests/test_trainer.py

@@ -172,6 +172,16 @@ if is_torch_available():
            for i in range(len(self.dataset)):
                yield self.dataset[i]

+    class FiniteIterableDataset(SampleIterableDataset):
+        def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
+            super().__init__(a, b, length, seed, label_names)
+            self.current_sample = 0
+
+        def __iter__(self):
+            while self.current_sample < len(self.dataset):
+                yield self.dataset[self.current_sample]
+                self.current_sample += 1
+
    class RegressionModel(nn.Module):
        def __init__(self, a=0, b=0, double_output=False):
            super().__init__()
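
A note on the FiniteIterableDataset helper added above: current_sample is deliberately never reset in __iter__, so once every sample has been yielded a second pass produces nothing. That is what lets the new test further down put the Trainer into the "data exhausted" state on its second epoch.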
@@ -856,7 +866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        self.assertAlmostEqual(b, b1, delta=1e-8)

    # regression for this issue: https://github.com/huggingface/transformers/issues/12970
-    def test_training_with_resume_from_checkpoint_flase(self):
+    def test_training_with_resume_from_checkpoint_false(self):
        train_dataset = RegressionDataset(length=128)
        eval_dataset = RegressionDataset()
@@ -1058,6 +1068,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        self.assertIsInstance(loader, torch.utils.data.DataLoader)
        self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)

+    def test_training_finite_iterable_dataset(self):
+        config = RegressionModelConfig()
+        model = RegressionPreTrainedModel(config)
+
+        batch_size = 1
+        num_samples = 10
+
+        available_steps = num_samples // batch_size
+
+        data = FiniteIterableDataset(length=num_samples)
+        train_args = TrainingArguments(
+            ".",
+            max_steps=available_steps + 1,  # set a higher number than actually available
+            per_device_train_batch_size=batch_size,
+        )
+        trainer = Trainer(model, train_dataset=data, args=train_args)
+        with self.assertLogs("transformers.trainer", level="WARNING") as logs:
+            trainer.train()
+            self.assertIn(f"stopping training at step {available_steps}!", logs.output[0])
+
    def test_evaluation_iterable_dataset(self):
        config = RegressionModelConfig(a=1.5, b=2.5)
        model = RegressionPreTrainedModel(config)
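
The arithmetic behind the new test_training_finite_iterable_dataset: with batch_size = 1 and num_samples = 10, the first epoch provides available_steps = 10 // 1 = 10 update steps. Since max_steps is 11, the Trainer starts a second epoch, but the exhausted FiniteIterableDataset yields nothing, so the new step < 0 branch fires and assertLogs captures the "stopping training at step 10!" warning.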