Trainer with Iterable Dataset (#7858)
* fix 5990 * accomodate iterable dataset without predefined length * set it as 1 use case: provide max_steps, and NO num_epochs * Is a merge of master and PR 5995 * fix trainer test under TF * fix only for torch * TF trainer untouched * trainer tests are skipped when no torch * address comments * fix quality checks * remove torch.dataset from test_trainer * unnecessary inheritance * RegressionDataset implements all needed methods __len__ and __getitem__ * fix quality checks * restore RegressionDataset * was wrongly under is_torch_available()
This commit is contained in:
parent
2422cda01b
commit
a09fe140c1
|
@ -16,7 +16,9 @@
|
|||
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
@ -283,6 +285,15 @@ class Trainer:
|
|||
FutureWarning,
|
||||
)
|
||||
|
||||
if args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
# Enforce rules on using datasets with no __len__
|
||||
if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
|
||||
raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
|
||||
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
|
||||
if is_datasets_available():
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
self._remove_unused_columns(self.train_dataset, description="training")
|
||||
|
@ -361,7 +372,7 @@ class Trainer:
|
|||
dataset.set_format(type=dataset.format["type"], columns=columns)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
if not isinstance(self.train_dataset, collections.abc.Sized):
|
||||
return None
|
||||
elif is_torch_tpu_available():
|
||||
return get_tpu_sampler(self.train_dataset)
|
||||
|
@ -376,7 +387,7 @@ class Trainer:
|
|||
"""
|
||||
Returns the training :class:`~torch.utils.data.DataLoader`.
|
||||
|
||||
Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
|
||||
Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler
|
||||
(adapted to distributed training if necessary) otherwise.
|
||||
|
||||
Subclass and override this method if you want to inject some custom behavior.
|
||||
|
@ -395,9 +406,7 @@ class Trainer:
|
|||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
||||
return None
|
||||
elif is_torch_tpu_available():
|
||||
if is_torch_tpu_available():
|
||||
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||
elif self.args.local_rank != -1:
|
||||
return SequentialDistributedSampler(eval_dataset)
|
||||
|
@ -408,19 +417,18 @@ class Trainer:
|
|||
"""
|
||||
Returns the evaluation :class:`~torch.utils.data.DataLoader`.
|
||||
|
||||
Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
|
||||
sampler (adapted to distributed training if necessary) otherwise.
|
||||
|
||||
Subclass and override this method if you want to inject some custom behavior.
|
||||
|
||||
Args:
|
||||
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
||||
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
|
||||
accepted by the ``model.forward()`` method are automatically removed.
|
||||
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
||||
"""
|
||||
if eval_dataset is None and self.eval_dataset is None:
|
||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||
elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||
elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||
self._remove_unused_columns(eval_dataset, description="evaluation")
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
|
@ -438,17 +446,16 @@ class Trainer:
|
|||
"""
|
||||
Returns the test :class:`~torch.utils.data.DataLoader`.
|
||||
|
||||
Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
|
||||
sampler (adapted to distributed training if necessary) otherwise.
|
||||
|
||||
Subclass and override this method if you want to inject some custom behavior.
|
||||
|
||||
Args:
|
||||
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
||||
test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
||||
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||
``model.forward()`` method are automatically removed.
|
||||
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
||||
"""
|
||||
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
||||
if not isinstance(test_dataset, collections.abc.Sized):
|
||||
raise ValueError("test_dataset must implement __len__")
|
||||
elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
||||
self._remove_unused_columns(test_dataset, description="test")
|
||||
test_sampler = self._get_eval_sampler(test_dataset)
|
||||
|
||||
|
@ -494,6 +501,8 @@ class Trainer:
|
|||
def num_examples(self, dataloader: DataLoader) -> int:
|
||||
"""
|
||||
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
||||
|
||||
Will raise an exception if the underlying dataset dese not implement method :obj:`__len__`
|
||||
"""
|
||||
return len(dataloader.dataset)
|
||||
|
||||
|
@ -579,19 +588,32 @@ class Trainer:
|
|||
# Reinitializes optimizer and scheduler
|
||||
self.optimizer, self.lr_scheduler = None, None
|
||||
|
||||
# Keeping track whether we can can len() on the dataset or not
|
||||
train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
|
||||
|
||||
# Data loader and number of training steps
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||
if self.args.max_steps > 0:
|
||||
max_steps = self.args.max_steps
|
||||
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
|
||||
self.args.max_steps % num_update_steps_per_epoch > 0
|
||||
)
|
||||
|
||||
# Setting up training control variables:
|
||||
# number of training epochs: num_train_epochs
|
||||
# number of training steps per epoch: num_update_steps_per_epoch
|
||||
# total number of training steps to execute: max_steps
|
||||
if train_dataset_is_sized:
|
||||
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||
if self.args.max_steps > 0:
|
||||
max_steps = self.args.max_steps
|
||||
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
|
||||
self.args.max_steps % num_update_steps_per_epoch > 0
|
||||
)
|
||||
else:
|
||||
max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
|
||||
num_train_epochs = math.ceil(self.args.num_train_epochs)
|
||||
else:
|
||||
max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
num_train_epochs = int(np.ceil(num_train_epochs))
|
||||
# see __init__. max_steps is set when the dataset has no __len__
|
||||
max_steps = self.args.max_steps
|
||||
num_train_epochs = 1
|
||||
num_update_steps_per_epoch = max_steps
|
||||
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
self.state = TrainerState()
|
||||
|
@ -645,8 +667,15 @@ class Trainer:
|
|||
* self.args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
|
||||
)
|
||||
|
||||
num_examples = (
|
||||
self.num_examples(train_dataloader)
|
||||
if train_dataset_is_sized
|
||||
else total_train_batch_size * self.args.max_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
|
||||
logger.info(" Num examples = %d", num_examples)
|
||||
logger.info(" Num Epochs = %d", num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
|
||||
|
@ -703,6 +732,7 @@ class Trainer:
|
|||
if self.args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
|
||||
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
|
||||
|
||||
for step, inputs in enumerate(epoch_iterator):
|
||||
|
@ -728,8 +758,8 @@ class Trainer:
|
|||
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||
len(epoch_iterator) <= self.args.gradient_accumulation_steps
|
||||
and (step + 1) == len(epoch_iterator)
|
||||
steps_in_epoch <= self.args.gradient_accumulation_steps
|
||||
and (step + 1) == steps_in_epoch
|
||||
):
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
|
@ -750,7 +780,7 @@ class Trainer:
|
|||
self.lr_scheduler.step()
|
||||
model.zero_grad()
|
||||
self.state.global_step += 1
|
||||
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
|
||||
self.state.epoch = epoch + (step + 1) / steps_in_epoch
|
||||
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||
|
@ -1207,11 +1237,15 @@ class Trainer:
|
|||
Args:
|
||||
eval_dataset (:obj:`Dataset`, `optional`):
|
||||
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
|
||||
columns not accepted by the ``model.forward()`` method are automatically removed.
|
||||
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
|
||||
the :obj:`__len__` method.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
|
||||
"""
|
||||
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
|
||||
output = self.prediction_loop(eval_dataloader, description="Evaluation")
|
||||
|
@ -1234,7 +1268,7 @@ class Trainer:
|
|||
Args:
|
||||
test_dataset (:obj:`Dataset`):
|
||||
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||
``model.forward()`` method are automatically removed.
|
||||
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
|
||||
|
||||
Returns:
|
||||
`NamedTuple`:
|
||||
|
@ -1245,6 +1279,9 @@ class Trainer:
|
|||
metrics (:obj:`Dict[str, float]`, `optional`):
|
||||
The potential dictionary of metrics (if the dataset contained labels).
|
||||
"""
|
||||
if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
|
||||
raise ValueError("test_dataset must implement __len__")
|
||||
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
|
||||
return self.prediction_loop(test_dataloader, description="Prediction")
|
||||
|
@ -1264,6 +1301,8 @@ class Trainer:
|
|||
)
|
||||
return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
|
||||
|
||||
if not isinstance(dataloader.dataset, collections.abc.Sized):
|
||||
raise ValueError("dataset must implement __len__")
|
||||
prediction_loss_only = (
|
||||
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
|
||||
)
|
||||
|
|
|
@ -31,11 +31,14 @@ if is_torch_available():
|
|||
from torch.utils.data import IterableDataset
|
||||
|
||||
from transformers import (
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollatorForLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
LineByLineTextDataset,
|
||||
PreTrainedModel,
|
||||
TextDataset,
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
|
@ -83,15 +86,16 @@ class RegressionModelConfig(PretrainedConfig):
|
|||
if is_torch_available():
|
||||
|
||||
class SampleIterableDataset(IterableDataset):
|
||||
def __init__(self, file_path):
|
||||
self.file_path = file_path
|
||||
"""
|
||||
Criteria is not whether it is IterableDataset or not, criteria is whether __len__ is implemented
|
||||
"""
|
||||
|
||||
def parse_file(self):
|
||||
f = open(self.file_path, "r")
|
||||
return f.readlines()
|
||||
def __init__(self, file_path, tokenizer):
|
||||
self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.parse_file())
|
||||
for i in range(len(self.ds)):
|
||||
yield self.ds[i]
|
||||
|
||||
class RegressionModel(torch.nn.Module):
|
||||
def __init__(self, a=0, b=0, double_output=False):
|
||||
|
@ -540,13 +544,51 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||
self.assertEqual(len(dataset), 31)
|
||||
|
||||
def test_trainer_iterable_dataset(self):
|
||||
# Simulate Language Modeling with an IterableDataset, with no __len__ method
|
||||
# Pick-up a tiny model, so it works on CPU
|
||||
# See Issue #5990: https://github.com/huggingface/transformers/issues/5990
|
||||
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
|
||||
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
|
||||
train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
|
||||
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
|
||||
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
|
||||
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
|
||||
trainer.train()
|
||||
|
||||
loader = trainer.get_train_dataloader()
|
||||
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
||||
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
||||
|
||||
# Exception if giving iterable dataset and no max_steps
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
_ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
|
||||
|
||||
# Exception if eval_dataset is iterable in __init__
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
|
||||
_ = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=train_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Exception if predicting with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
|
||||
trainer.predict(train_dataset)
|
||||
|
||||
# Exception if evaluating with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
|
||||
trainer.evaluate(train_dataset)
|
||||
|
||||
def test_num_train_epochs_in_training(self):
|
||||
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
|
||||
|
|
Loading…
Reference in New Issue