Trainer support for IterableDataset (#5834)
* Don't pass a sampler for an iterable dataset
* Added check for test and eval dataloaders
* Formatting
* Cleaner if nesting
* Added test for Trainer with an iterable dataset
* Formatting for test
* Fixed import when torch is available only
* Added require_torch decorator to helper class
* Moved dataset class inside unittest
* Removed nested if and changed model in test
* Checking torch availability for IterableDataset
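Background for the change (a minimal sketch, not part of the patch): torch.utils.data.DataLoader raises a ValueError when an explicit sampler is combined with an iterable-style dataset, so the Trainer has to skip sampler creation for IterableDataset inputs. The TinyIterable class below is illustrative only:

import torch
from torch.utils.data import DataLoader, IterableDataset, RandomSampler


class TinyIterable(IterableDataset):
    # Illustrative iterable-style dataset yielding the integers 0..7.
    def __iter__(self):
        return iter(range(8))


ds = TinyIterable()
# Without a sampler, the loader just batches the stream in order.
print([batch.tolist() for batch in DataLoader(ds, batch_size=4)])  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Passing any sampler alongside an IterableDataset is rejected by PyTorch.
try:
    DataLoader(ds, batch_size=4, sampler=RandomSampler(range(8)))
except ValueError as err:
    print(err)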
parent 82dd96cae7
commit 290b6e18ac
src/transformers/trainer.py

@@ -230,9 +230,11 @@ class Trainer:
         """
         Returns the training :class:`~torch.utils.data.DataLoader`.
         """
-        if self.train_dataset is None:
+        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
+            train_sampler = None
+        elif self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
-        if is_torch_tpu_available():
+        elif is_torch_tpu_available():
             train_sampler = get_tpu_sampler(self.train_dataset)
         else:
             train_sampler = (
@@ -240,7 +242,6 @@ class Trainer:
             if self.args.local_rank == -1
             else DistributedSampler(self.train_dataset)
         )
-
         data_loader = DataLoader(
             self.train_dataset,
             batch_size=self.args.train_batch_size,
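Read together, the two hunks above turn sampler selection into a single if/elif chain. Below is a standalone sketch of that logic under stand-in names (pick_train_sampler is hypothetical, the TPU branch is elided, and the RandomSampler fallback is inferred from the surrounding context lines):

from torch.utils.data import DistributedSampler, IterableDataset, RandomSampler


def pick_train_sampler(dataset, local_rank):
    # Iterable-style datasets define their own iteration order: no sampler.
    if isinstance(dataset, IterableDataset):
        return None
    if dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")
    # Map-style datasets: random sampling locally, sharded sampling when distributed.
    return RandomSampler(dataset) if local_rank == -1 else DistributedSampler(dataset)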
@@ -264,7 +265,9 @@ class Trainer:

         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

-        if is_torch_tpu_available():
+        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            sampler = None
+        elif is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )
@@ -291,7 +294,9 @@ class Trainer:
             test_dataset (obj:`Dataset`): The test dataset to use.
         """
         # We use the same batch_size as for eval.
-        if is_torch_tpu_available():
+        if isinstance(test_dataset, torch.utils.data.IterableDataset):
+            sampler = None
+        elif is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )
@@ -307,7 +312,6 @@ class Trainer:
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
-
         return data_loader

     def get_optimizers(
tests/test_trainer.py

@@ -6,16 +6,18 @@ from transformers.testing_utils import require_torch

 if is_torch_available():
     import torch
+    from torch.utils.data import IterableDataset

     from transformers import (
-        Trainer,
-        LineByLineTextDataset,
         AutoModelForSequenceClassification,
-        default_data_collator,
         DataCollatorForLanguageModeling,
         DataCollatorForPermutationLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
+        LineByLineTextDataset,
         TextDataset,
+        Trainer,
+        default_data_collator,
     )
@@ -153,6 +155,20 @@ class DataCollatorIntegrationTest(unittest.TestCase):
             data_collator(example)


+if is_torch_available():
+
+    class SampleIterableDataset(IterableDataset):
+        def __init__(self, file_path):
+            self.file_path = file_path
+
+        def parse_file(self):
+            f = open(self.file_path, "r")
+            return f.readlines()
+
+        def __iter__(self):
+            return iter(self.parse_file())
+
+
 @require_torch
 class TrainerIntegrationTest(unittest.TestCase):
     def test_trainer_eval_mrpc(self):
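A quick usage sketch for the helper added above, assuming sample.txt is any small text file (the path is illustrative). With the default collate function, each batch is simply a list of raw line strings; note that parse_file reads the whole file eagerly and leaves closing the handle to the garbage collector.

from torch.utils.data import DataLoader

loader = DataLoader(SampleIterableDataset("sample.txt"), batch_size=2)
for batch in loader:
    print(batch)  # e.g. ['first line\n', 'second line\n']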
@@ -176,3 +192,12 @@ class TrainerIntegrationTest(unittest.TestCase):
             tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
         )
         self.assertEqual(len(dataset), 31)
+
+    def test_trainer_iterable_dataset(self):
+        MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
+        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
+        train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
+        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
+        loader = trainer.get_train_dataloader()
+        self.assertIsInstance(loader, torch.utils.data.DataLoader)
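A hedged follow-on assertion that could be appended inside test_trainer_iterable_dataset (not in the commit): because SampleIterableDataset defines no __len__, neither does the resulting loader, so downstream code cannot rely on len(dataloader).

# Hypothetical extra check, valid inside the test method above:
with self.assertRaises(TypeError):
    len(loader)  # iterable-style dataset without __len__ -> loader has no length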