Avoid looping when data exhausted (#14413)
* Stop training when a finite IterableDataset is exhausted. When using an iterable dataset, num_epochs is set to sys.maxsize to make sure all data is consumed; likewise, we want to set max_steps high enough but still stop when all data is consumed. (cherry picked from commit 6f0e1d6363153da9051e93acffe1cbab3a3f3b12) * Fix typo: flase -> false. * Add test for stopping training on an exhausted finite iterable dataset. * Remove redundant gradient_accumulation_steps. * Run make style; reformat training_args docstring.
This commit is contained in:
parent
3e8d17e66d
commit
a33168aa78
|
@ -1287,6 +1287,7 @@ class Trainer:
|
|||
)
|
||||
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
|
||||
|
||||
step = -1
|
||||
for step, inputs in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
|
@ -1386,6 +1387,13 @@ class Trainer:
|
|||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
if step < 0:
|
||||
logger.warning(
|
||||
f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
|
||||
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
|
||||
f" num_steps ({max_steps}) higher than the number of available samples."
|
||||
)
|
||||
self.control.should_training_stop = True
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
|
||||
|
|
|
@ -141,7 +141,8 @@ class TrainingArguments:
|
|||
the last epoch before stopping training).
|
||||
max_steps (:obj:`int`, `optional`, defaults to -1):
|
||||
If set to a positive number, the total number of training steps to perform. Overrides
|
||||
:obj:`num_train_epochs`.
|
||||
:obj:`num_train_epochs`. In case of using a finite iterable dataset the training may stop before reaching
|
||||
the set number of steps when all data is exhausted
|
||||
lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
|
||||
The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
|
||||
values.
|
||||
|
|
|
@ -172,6 +172,16 @@ if is_torch_available():
|
|||
for i in range(len(self.dataset)):
|
||||
yield self.dataset[i]
|
||||
|
||||
class FiniteIterableDataset(SampleIterableDataset):
    """Iterable dataset whose samples can be consumed only once.

    The read position is stored on the instance and is never reset by this
    class, so a new iteration resumes where the previous one stopped and
    yields nothing once every sample has been handed out.
    """

    def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
        super().__init__(a, b, length, seed, label_names)
        # Index of the next sample to hand out; shared by all iterations.
        self.current_sample = 0

    def __iter__(self):
        while True:
            if self.current_sample >= len(self.dataset):
                return
            yield self.dataset[self.current_sample]
            # Advance only after the consumer asks for the next item.
            self.current_sample += 1
|
||||
|
||||
class RegressionModel(nn.Module):
|
||||
def __init__(self, a=0, b=0, double_output=False):
|
||||
super().__init__()
|
||||
|
@ -856,7 +866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||
self.assertAlmostEqual(b, b1, delta=1e-8)
|
||||
|
||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||
def test_training_with_resume_from_checkpoint_flase(self):
|
||||
def test_training_with_resume_from_checkpoint_false(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
eval_dataset = RegressionDataset()
|
||||
|
||||
|
@ -1058,6 +1068,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
||||
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
||||
|
||||
def test_training_finite_iterable_dataset(self):
    """Check that training stops once a finite iterable dataset runs dry.

    Requests one more step than the dataset can supply and verifies that the
    first warning logged on ``transformers.trainer`` reports training
    stopping at the last available step.
    """
    batch_size = 1
    num_samples = 10
    # Number of full batches the dataset is able to provide.
    available_steps = num_samples // batch_size

    model = RegressionPreTrainedModel(RegressionModelConfig())
    data = FiniteIterableDataset(length=num_samples)

    train_args = TrainingArguments(
        ".",
        per_device_train_batch_size=batch_size,
        # Deliberately ask for more steps than there is data.
        max_steps=available_steps + 1,
    )
    trainer = Trainer(model, args=train_args, train_dataset=data)

    with self.assertLogs("transformers.trainer", level="WARNING") as logs:
        trainer.train()

    first_record = logs.output[0]
    self.assertIn(f"stopping training at step {available_steps}!", first_record)
|
||||
|
||||
def test_evaluation_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
|
|
Loading…
Reference in New Issue