Trainer support for IterableDataset (#5834)
* Don't pass sampler for iterable dataset * Added check for test and eval dataloaders. * Formatting * Don't pass sampler for iterable dataset * Added check for test and eval dataloaders. * Formatting * Cleaner if nesting. * Added test for trainer and iterable dataset * Formatting for test * Fixed import when torch is available only. * Added require torch decorator to helper class * Moved dataset class inside unittest * Removed nested if and changed model in test * Checking torch availability for IterableDataset
This commit is contained in:
parent
82dd96cae7
commit
290b6e18ac
|
@ -230,9 +230,11 @@ class Trainer:
|
||||||
"""
|
"""
|
||||||
Returns the training :class:`~torch.utils.data.DataLoader`.
|
Returns the training :class:`~torch.utils.data.DataLoader`.
|
||||||
"""
|
"""
|
||||||
if self.train_dataset is None:
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
|
train_sampler = None
|
||||||
|
elif self.train_dataset is None:
|
||||||
raise ValueError("Trainer: training requires a train_dataset.")
|
raise ValueError("Trainer: training requires a train_dataset.")
|
||||||
if is_torch_tpu_available():
|
elif is_torch_tpu_available():
|
||||||
train_sampler = get_tpu_sampler(self.train_dataset)
|
train_sampler = get_tpu_sampler(self.train_dataset)
|
||||||
else:
|
else:
|
||||||
train_sampler = (
|
train_sampler = (
|
||||||
|
@ -240,7 +242,6 @@ class Trainer:
|
||||||
if self.args.local_rank == -1
|
if self.args.local_rank == -1
|
||||||
else DistributedSampler(self.train_dataset)
|
else DistributedSampler(self.train_dataset)
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
self.train_dataset,
|
self.train_dataset,
|
||||||
batch_size=self.args.train_batch_size,
|
batch_size=self.args.train_batch_size,
|
||||||
|
@ -264,7 +265,9 @@ class Trainer:
|
||||||
|
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
||||||
|
sampler = None
|
||||||
|
elif is_torch_tpu_available():
|
||||||
sampler = SequentialDistributedSampler(
|
sampler = SequentialDistributedSampler(
|
||||||
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
|
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
|
||||||
)
|
)
|
||||||
|
@ -291,7 +294,9 @@ class Trainer:
|
||||||
test_dataset (obj:`Dataset`): The test dataset to use.
|
test_dataset (obj:`Dataset`): The test dataset to use.
|
||||||
"""
|
"""
|
||||||
# We use the same batch_size as for eval.
|
# We use the same batch_size as for eval.
|
||||||
if is_torch_tpu_available():
|
if isinstance(self.test_dataset, torch.utils.data.IterableDataset):
|
||||||
|
sampler = None
|
||||||
|
elif is_torch_tpu_available():
|
||||||
sampler = SequentialDistributedSampler(
|
sampler = SequentialDistributedSampler(
|
||||||
test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
|
test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
|
||||||
)
|
)
|
||||||
|
@ -307,7 +312,6 @@ class Trainer:
|
||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|
||||||
def get_optimizers(
|
def get_optimizers(
|
||||||
|
|
|
@ -6,16 +6,18 @@ from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
Trainer,
|
|
||||||
LineByLineTextDataset,
|
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
default_data_collator,
|
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
DataCollatorForPermutationLanguageModeling,
|
DataCollatorForPermutationLanguageModeling,
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
GlueDataTrainingArguments,
|
GlueDataTrainingArguments,
|
||||||
|
LineByLineTextDataset,
|
||||||
TextDataset,
|
TextDataset,
|
||||||
|
Trainer,
|
||||||
|
default_data_collator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,6 +155,20 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||||
data_collator(example)
|
data_collator(example)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
|
||||||
|
class SampleIterableDataset(IterableDataset):
    """Minimal iterable-style dataset that yields the lines of a text file.

    Used as a test fixture to check that ``Trainer`` can build a dataloader
    from an ``IterableDataset``.

    Args:
        file_path: Path of the text file whose lines are served as samples.
    """

    def __init__(self, file_path):
        self.file_path = file_path

    def parse_file(self):
        """Read the whole file and return its lines (newlines preserved).

        The file is opened in a ``with`` block so the handle is closed as
        soon as the lines are read, instead of being leaked until garbage
        collection.
        """
        with open(self.file_path, "r") as f:
            return f.readlines()

    def __iter__(self):
        """Return a fresh iterator over the file's lines on every call."""
        return iter(self.parse_file())
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TrainerIntegrationTest(unittest.TestCase):
|
class TrainerIntegrationTest(unittest.TestCase):
|
||||||
def test_trainer_eval_mrpc(self):
|
def test_trainer_eval_mrpc(self):
|
||||||
|
@ -176,3 +192,12 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||||
tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
|
tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
|
||||||
)
|
)
|
||||||
self.assertEqual(len(dataset), 31)
|
self.assertEqual(len(dataset), 31)
|
||||||
|
|
||||||
|
def test_trainer_iterable_dataset(self):
    """Trainer must return a DataLoader when the train set is an IterableDataset."""
    checkpoint = "sshleifer/tiny-distilbert-base-cased"
    # Keyword arguments are evaluated in written order, which keeps the
    # construction sequence: model, then dataset, then training arguments.
    trainer = Trainer(
        model=AutoModelForSequenceClassification.from_pretrained(checkpoint),
        train_dataset=SampleIterableDataset(PATH_SAMPLE_TEXT),
        args=TrainingArguments(output_dir="./examples", no_cuda=True),
    )
    self.assertIsInstance(trainer.get_train_dataloader(), torch.utils.data.DataLoader)
|
||||||
|
|
Loading…
Reference in New Issue