Trainer with Iterable Dataset (#7858)

* fix 5990

* accomodate iterable dataset without predefined length
* set it as 1 use case: provide max_steps, and NO num_epochs
* Is a merge of master and PR 5995

* fix trainer test under TF

* fix only for torch
* TF trainer untouched
* trainer tests are skipped when no torch

* address comments

* fix quality checks

* remove torch.dataset from test_trainer

* unnecessary inheritance
* RegressionDataset implements all needed methods __len__ and __getitem__

* fix quality checks

* restore RegressionDataset

* was wrongly under is_torch_available()
This commit is contained in:
Julien Rossi 2020-10-19 17:57:39 +02:00 committed by GitHub
parent 2422cda01b
commit a09fe140c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 123 additions and 42 deletions

View File

@ -16,7 +16,9 @@
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""
import collections
import inspect
import math
import os
import re
import shutil
@ -283,6 +285,15 @@ class Trainer:
FutureWarning,
)
if args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
# Enforce rules on using datasets with no __len__
if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")
if is_datasets_available():
if isinstance(train_dataset, datasets.Dataset):
self._remove_unused_columns(self.train_dataset, description="training")
@ -361,7 +372,7 @@ class Trainer:
dataset.set_format(type=dataset.format["type"], columns=columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
if not isinstance(self.train_dataset, collections.abc.Sized):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
@ -376,7 +387,7 @@ class Trainer:
"""
Returns the training :class:`~torch.utils.data.DataLoader`.
Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler
(adapted to distributed training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
@ -395,9 +406,7 @@ class Trainer:
)
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
if is_torch_tpu_available():
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif self.args.local_rank != -1:
return SequentialDistributedSampler(eval_dataset)
@ -408,19 +417,18 @@ class Trainer:
"""
Returns the evaluation :class:`~torch.utils.data.DataLoader`.
Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
sampler (adapted to distributed training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
accepted by the ``model.forward()`` method are automatically removed.
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")
elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(eval_dataset, description="evaluation")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
eval_sampler = self._get_eval_sampler(eval_dataset)
@ -438,17 +446,16 @@ class Trainer:
"""
Returns the test :class:`~torch.utils.data.DataLoader`.
Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
sampler (adapted to distributed training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
if not isinstance(test_dataset, collections.abc.Sized):
raise ValueError("test_dataset must implement __len__")
elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
self._remove_unused_columns(test_dataset, description="test")
test_sampler = self._get_eval_sampler(test_dataset)
@ -494,6 +501,8 @@ class Trainer:
def num_examples(self, dataloader: DataLoader) -> int:
"""
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
Will raise an exception if the underlying dataset dese not implement method :obj:`__len__`
"""
return len(dataloader.dataset)
@ -579,19 +588,32 @@ class Trainer:
# Reinitializes optimizer and scheduler
self.optimizer, self.lr_scheduler = None, None
# Keeping track whether we can can len() on the dataset or not
train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
)
# Setting up training control variables:
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
if train_dataset_is_sized:
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
)
else:
max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(self.args.num_train_epochs)
else:
max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
num_train_epochs = int(np.ceil(num_train_epochs))
# see __init__. max_steps is set when the dataset has no __len__
max_steps = self.args.max_steps
num_train_epochs = 1
num_update_steps_per_epoch = max_steps
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
@ -645,8 +667,15 @@ class Trainer:
* self.args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
)
num_examples = (
self.num_examples(train_dataloader)
if train_dataset_is_sized
else total_train_batch_size * self.args.max_steps
)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
logger.info(" Num examples = %d", num_examples)
logger.info(" Num Epochs = %d", num_train_epochs)
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
@ -703,6 +732,7 @@ class Trainer:
if self.args.past_index >= 0:
self._past = None
steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
for step, inputs in enumerate(epoch_iterator):
@ -728,8 +758,8 @@ class Trainer:
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
len(epoch_iterator) <= self.args.gradient_accumulation_steps
and (step + 1) == len(epoch_iterator)
steps_in_epoch <= self.args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
if self.args.fp16 and _use_native_amp:
self.scaler.unscale_(self.optimizer)
@ -750,7 +780,7 @@ class Trainer:
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
@ -1207,11 +1237,15 @@ class Trainer:
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed.
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
the :obj:`__len__` method.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
"""
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")
eval_dataloader = self.get_eval_dataloader(eval_dataset)
output = self.prediction_loop(eval_dataloader, description="Evaluation")
@ -1234,7 +1268,7 @@ class Trainer:
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
Returns:
`NamedTuple`:
@ -1245,6 +1279,9 @@ class Trainer:
metrics (:obj:`Dict[str, float]`, `optional`):
The potential dictionary of metrics (if the dataset contained labels).
"""
if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
raise ValueError("test_dataset must implement __len__")
test_dataloader = self.get_test_dataloader(test_dataset)
return self.prediction_loop(test_dataloader, description="Prediction")
@ -1264,6 +1301,8 @@ class Trainer:
)
return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
if not isinstance(dataloader.dataset, collections.abc.Sized):
raise ValueError("dataset must implement __len__")
prediction_loss_only = (
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)

62
tests/test_trainer.py Executable file → Normal file
View File

@ -31,11 +31,14 @@ if is_torch_available():
from torch.utils.data import IterableDataset
from transformers import (
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
DataCollatorForLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
PreTrainedModel,
TextDataset,
Trainer,
TrainerState,
)
@ -83,15 +86,16 @@ class RegressionModelConfig(PretrainedConfig):
if is_torch_available():
class SampleIterableDataset(IterableDataset):
def __init__(self, file_path):
self.file_path = file_path
"""
Criteria is not whether it is IterableDataset or not, criteria is whether __len__ is implemented
"""
def parse_file(self):
f = open(self.file_path, "r")
return f.readlines()
def __init__(self, file_path, tokenizer):
self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
def __iter__(self):
return iter(self.parse_file())
for i in range(len(self.ds)):
yield self.ds[i]
class RegressionModel(torch.nn.Module):
def __init__(self, a=0, b=0, double_output=False):
@ -540,13 +544,51 @@ class TrainerIntegrationTest(unittest.TestCase):
self.assertEqual(len(dataset), 31)
def test_trainer_iterable_dataset(self):
# Simulate Language Modeling with an IterableDataset, with no __len__ method
# Pick-up a tiny model, so it works on CPU
# See Issue #5990: https://github.com/huggingface/transformers/issues/5990
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
trainer.train()
loader = trainer.get_train_dataloader()
self.assertIsInstance(loader, torch.utils.data.DataLoader)
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
# Exception if giving iterable dataset and no max_steps
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
_ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
# Exception if eval_dataset is iterable in __init__
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
_ = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=train_dataset,
data_collator=data_collator,
)
# Exception if predicting with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.predict(train_dataset)
# Exception if evaluating with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.evaluate(train_dataset)
def test_num_train_epochs_in_training(self):
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.