From 290b6e18ac2848fc29bfa6a90bd702f642baea7e Mon Sep 17 00:00:00 2001
From: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
Date: Mon, 20 Jul 2020 06:07:37 -0700
Subject: [PATCH] Trainer support for iterabledataset (#5834)

* Don't pass sampler for iterable dataset

* Added check for test and eval dataloaders.

* Formatting

* Don't pass sampler for iterable dataset

* Added check for test and eval dataloaders.

* Formatting

* Cleaner if nesting.

* Added test for trainer and iterable dataset

* Formatting for test

* Fixed import when torch is available only.

* Added require torch decorator to helper class

* Moved dataset class inside unittest

* Removed nested if and changed model in test

* Checking torch availability for IterableDataset
---
 src/transformers/trainer.py | 16 ++++++++++------
 tests/test_trainer.py       | 31 ++++++++++++++++++++++++++++---
 2 files changed, 38 insertions(+), 9 deletions(-)
 mode change 100644 => 100755 src/transformers/trainer.py
 mode change 100644 => 100755 tests/test_trainer.py

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
old mode 100644
new mode 100755
index e21e6cdc46..23333f49ca
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -230,9 +230,11 @@ class Trainer:
         """
         Returns the training :class:`~torch.utils.data.DataLoader`.
         """
-        if self.train_dataset is None:
+        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
+            train_sampler = None
+        elif self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
-        if is_torch_tpu_available():
+        elif is_torch_tpu_available():
             train_sampler = get_tpu_sampler(self.train_dataset)
         else:
             train_sampler = (
@@ -240,7 +242,6 @@ class Trainer:
                 if self.args.local_rank == -1
                 else DistributedSampler(self.train_dataset)
             )
-
         data_loader = DataLoader(
             self.train_dataset,
             batch_size=self.args.train_batch_size,
@@ -264,7 +265,9 @@ class Trainer:
 
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
 
-        if is_torch_tpu_available():
+        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            sampler = None
+        elif is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )
@@ -291,7 +294,9 @@ class Trainer:
             test_dataset (obj:`Dataset`): The test dataset to use.
         """
         # We use the same batch_size as for eval.
-        if is_torch_tpu_available():
+        if isinstance(test_dataset, torch.utils.data.IterableDataset):
+            sampler = None
+        elif is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )
@@ -307,7 +312,6 @@ class Trainer:
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
-
         return data_loader
 
     def get_optimizers(
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
old mode 100644
new mode 100755
index dd5a487be4..f5ad0bbeec
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -6,16 +6,18 @@ from transformers.testing_utils import require_torch
 
 if is_torch_available():
     import torch
+    from torch.utils.data import IterableDataset
+
     from transformers import (
-        Trainer,
-        LineByLineTextDataset,
         AutoModelForSequenceClassification,
-        default_data_collator,
         DataCollatorForLanguageModeling,
         DataCollatorForPermutationLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
+        LineByLineTextDataset,
         TextDataset,
+        Trainer,
+        default_data_collator,
     )
 
 
@@ -153,6 +155,20 @@ class DataCollatorIntegrationTest(unittest.TestCase):
             data_collator(example)
 
 
+if is_torch_available():
+
+    class SampleIterableDataset(IterableDataset):
+        def __init__(self, file_path):
+            self.file_path = file_path
+
+        def parse_file(self):
+            with open(self.file_path, "r") as f:
+                return f.readlines()
+
+        def __iter__(self):
+            return iter(self.parse_file())
+
+
 @require_torch
 class TrainerIntegrationTest(unittest.TestCase):
     def test_trainer_eval_mrpc(self):
@@ -176,3 +192,12 @@ class TrainerIntegrationTest(unittest.TestCase):
             tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
         )
         self.assertEqual(len(dataset), 31)
+
+    def test_trainer_iterable_dataset(self):
+        MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
+        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
+        train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
+        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
+        loader = trainer.get_train_dataloader()
+        self.assertIsInstance(loader, torch.utils.data.DataLoader)
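
A minimal usage sketch of the behavior this patch enables, mirroring the new test: a torch IterableDataset is passed straight to Trainer, and get_train_dataloader builds a DataLoader without attaching a distributed or random sampler. The model ID matches the one used in the test; the dataset class, file path, and output_dir below are illustrative placeholders, not part of the patch.

    import torch
    from torch.utils.data import IterableDataset

    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    class LineStream(IterableDataset):
        # Hypothetical helper: yields one line of text per iteration.
        def __init__(self, file_path):
            self.file_path = file_path

        def __iter__(self):
            with open(self.file_path, "r") as f:
                yield from f

    model = AutoModelForSequenceClassification.from_pretrained("sshleifer/tiny-distilbert-base-cased")
    args = TrainingArguments(output_dir="./iterable_out", no_cuda=True)  # placeholder output dir
    trainer = Trainer(model=model, args=args, train_dataset=LineStream("sample.txt"))  # placeholder path

    loader = trainer.get_train_dataloader()
    assert isinstance(loader, torch.utils.data.DataLoader)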