Trainer support for IterableDataset (#5834)

* Don't pass sampler for iterable dataset

* Added check for test and eval dataloaders.

* Formatting

* Cleaner if nesting.

* Added test for trainer and iterable dataset

* Formatting for test

* Fixed import so it only happens when torch is available.

* Added require torch decorator to helper class

* Moved dataset class inside unittest

* Removed nested if and changed model in test

* Checking torch availability for IterableDataset
Pradhy729 2020-07-20 06:07:37 -07:00 committed by GitHub
parent 82dd96cae7
commit 290b6e18ac
2 changed files with 38 additions and 9 deletions
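Note (not part of the commit): the diffs below stop passing a sampler whenever the training, eval, or test dataset is an IterableDataset. A minimal, self-contained sketch of why, using only torch and illustrative names: PyTorch's DataLoader rejects any sampler for iterable-style datasets, so the only valid choice is to build the loader without one.

from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


class StreamOfNumbers(IterableDataset):
    """Toy streaming dataset: no __len__ and no random access."""

    def __iter__(self):
        return iter(range(10))


stream = StreamOfNumbers()

# Passing any sampler fails at DataLoader construction time.
try:
    DataLoader(stream, sampler=SequentialSampler(stream))
except ValueError as err:
    print("sampler rejected:", err)

# Without a sampler the loader just consumes the iterator in order.
loader = DataLoader(stream, batch_size=4)
print([batch.tolist() for batch in loader])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]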

src/transformers/trainer.py (16 changes, Normal file → Executable file)

@@ -230,9 +230,11 @@ class Trainer:
         """
         Returns the training :class:`~torch.utils.data.DataLoader`.
         """
-        if self.train_dataset is None:
+        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
+            train_sampler = None
+        elif self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
-        if is_torch_tpu_available():
+        elif is_torch_tpu_available():
             train_sampler = get_tpu_sampler(self.train_dataset)
         else:
             train_sampler = (
@@ -240,7 +242,6 @@ class Trainer:
                 if self.args.local_rank == -1
                 else DistributedSampler(self.train_dataset)
             )
-
         data_loader = DataLoader(
             self.train_dataset,
             batch_size=self.args.train_batch_size,
@@ -264,7 +265,9 @@ class Trainer:
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-        if is_torch_tpu_available():
+        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            sampler = None
+        elif is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )
@@ -291,7 +294,9 @@ class Trainer:
             test_dataset (obj:`Dataset`): The test dataset to use.
         """
         # We use the same batch_size as for eval.
-        if is_torch_tpu_available():
+        if isinstance(self.test_dataset, torch.utils.data.IterableDataset):
+            sampler = None
+        elif is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )
@@ -307,7 +312,6 @@ class Trainer:
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
-
         return data_loader

     def get_optimizers(
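Note (not part of the commit): a condensed sketch of the branch order the train dataloader now uses; pick_sampler and its arguments are illustrative only, and the eval/test methods follow the same shape but choose sequential samplers instead of random ones.

from typing import Optional

import torch
from torch.utils.data import Dataset, DistributedSampler, RandomSampler, Sampler


def pick_sampler(dataset: Optional[Dataset], local_rank: int = -1) -> Optional[Sampler]:
    # Streaming datasets have no known length and no random access, so the
    # DataLoader has to be built without any sampler.
    if isinstance(dataset, torch.utils.data.IterableDataset):
        return None
    if dataset is None:
        raise ValueError("a dataset is required")
    # (The real get_train_dataloader also checks is_torch_tpu_available() here
    # and uses a TPU-aware sampler in that branch.)
    if local_rank == -1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset)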

tests/test_trainer.py (31 changes, Normal file → Executable file)

@@ -6,16 +6,18 @@ from transformers.testing_utils import require_torch
 if is_torch_available():
     import torch
+    from torch.utils.data import IterableDataset

     from transformers import (
-        Trainer,
-        LineByLineTextDataset,
         AutoModelForSequenceClassification,
-        default_data_collator,
         DataCollatorForLanguageModeling,
         DataCollatorForPermutationLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
+        LineByLineTextDataset,
         TextDataset,
+        Trainer,
+        default_data_collator,
     )
@@ -153,6 +155,20 @@ class DataCollatorIntegrationTest(unittest.TestCase):
             data_collator(example)


+if is_torch_available():
+
+    class SampleIterableDataset(IterableDataset):
+        def __init__(self, file_path):
+            self.file_path = file_path
+
+        def parse_file(self):
+            f = open(self.file_path, "r")
+            return f.readlines()
+
+        def __iter__(self):
+            return iter(self.parse_file())
+
+
 @require_torch
 class TrainerIntegrationTest(unittest.TestCase):
     def test_trainer_eval_mrpc(self):
@@ -176,3 +192,12 @@ class TrainerIntegrationTest(unittest.TestCase):
             tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
         )
         self.assertEqual(len(dataset), 31)
+
+    def test_trainer_iterable_dataset(self):
+        MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
+        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
+        train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
+        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
+        loader = trainer.get_train_dataloader()
+        self.assertIsInstance(loader, torch.utils.data.DataLoader)
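Note (not part of the commit): a standalone sketch of how a file-backed iterable dataset in the style of the new SampleIterableDataset behaves when handed to a plain DataLoader; the class name and the choice of file are illustrative. There is no sampler and no shuffling, so lines come back in file order, and the default collate leaves raw strings as lists.

from torch.utils.data import DataLoader, IterableDataset


class LineStream(IterableDataset):
    """Yields raw lines from a text file; defines no __len__ and no __getitem__."""

    def __init__(self, file_path):
        self.file_path = file_path

    def __iter__(self):
        with open(self.file_path, "r") as f:
            yield from f.readlines()


# Stream this script's own source as sample text.
loader = DataLoader(LineStream(__file__), batch_size=4)
for batch in loader:
    # Each batch is a list of up to 4 raw strings, in file order.
    print(len(batch), batch[0].rstrip())

Because such a dataset exposes no __len__, len(loader) would raise a TypeError, which is presumably why the new test stops at asserting the type of the returned DataLoader.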