Add predict step accumulation (#7767)
* Add eval_accumulation_steps and clean distributed eval
* Add TPU test
* Add TPU stuff
* Fix arg name
* Fix Seq2SeqTrainer
* Fix total_size
* Update src/transformers/trainer_pt_utils.py
* Doc and add test to TPU
* Add unit test
* Adapt name

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in: parent 8feb0cc967, commit a1d1b332d0
@@ -19,3 +19,9 @@ Callbacks internals
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.trainer_callback.CallbackHandler

Distributed Evaluation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.trainer_pt_utils.DistributedTensorGatherer
    :members:
@@ -174,7 +174,7 @@ class Seq2SeqTrainer(Trainer):
            # Call forward again to get loss # TODO: avoidable?
            outputs = model(**inputs, use_cache=False)
            loss = self._compute_loss(outputs[1], labels_out)
            loss = loss.mean().item()
            loss = loss.mean().detach()
            if self.args.prediction_loss_only:
                return (loss, None, None)
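The swap above from `.item()` to `.detach()` matters for the new accumulation path: a detached loss is still a tensor that can be repeated per sample and concatenated or gathered with the other evaluation tensors, while `.item()` collapses it to a plain Python float. A minimal standalone illustration, with made-up values (not part of the diff):

```python
import torch

loss = torch.tensor([0.7, 0.3]).mean()   # hypothetical per-batch loss

as_float = loss.item()      # plain Python float, no longer a tensor
as_tensor = loss.detach()   # tensor cut from the autograd graph

# The detached tensor can be expanded per sample and concatenated later,
# mirroring what the new evaluation loop does with `loss.repeat(batch_size)`.
per_sample = as_tensor.repeat(4)          # pretend batch_size == 4
print(as_float, per_sample.shape)         # 0.5 torch.Size([4])
```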
@@ -81,3 +81,14 @@ class TorchXLAExamplesTests(unittest.TestCase):

        # Assert that the script takes less than 300 seconds to make sure it doesn't hang.
        self.assertLess(end - start, 300)

    def test_trainer_tpu(self):
        import xla_spawn

        testargs = """
            transformers/tests/test_trainer_tpu.py
            --num_cores=8
            transformers/tests/test_trainer_tpu.py
            """.split()
        with patch.object(sys, "argv", testargs):
            xla_spawn.main()
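In the test above, `sys.argv` is patched before calling `xla_spawn.main()`; the first entry effectively just fills the program-name slot (`argv[0]`) that `argparse` skips, which is presumably why the test file path appears twice. A small standalone sketch of the same pattern (hypothetical parser, not the real `xla_spawn` argument set):

```python
import sys
from argparse import ArgumentParser
from unittest.mock import patch

parser = ArgumentParser()
parser.add_argument("--num_cores", type=int)
parser.add_argument("training_script", type=str)

# argv[0] is treated as the program name and ignored by parse_args().
testargs = ["tests/test_trainer_tpu.py", "--num_cores=8", "tests/test_trainer_tpu.py"]
with patch.object(sys, "argv", testargs):
    args = parser.parse_args()               # reads sys.argv[1:]
print(args.num_cores, args.training_script)  # 8 tests/test_trainer_tpu.py
```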
@@ -59,6 +59,7 @@ from .trainer_callback import (
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedTensorGatherer,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,

@@ -1266,18 +1267,29 @@ class Trainer:
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
            # Note: in torch.distributed mode, there's no point in wrapping the model
            # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info("***** Running %s *****", description)
        logger.info(" Num examples = %d", self.num_examples(dataloader))
        logger.info(" Num examples = %d", num_examples)
        logger.info(" Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = 1
        if is_torch_tpu_available():
            world_size = xm.xrt_world_size()
        elif self.args.local_rank != -1:
            world_size = torch.distributed.get_world_size()
        world_size = max(1, world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
        labels_gatherer = DistributedTensorGatherer(world_size, num_examples)

        model.eval()

        if is_torch_tpu_available():
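The three gatherers above are sized slightly differently: the per-step loss is repeated `batch_size` times (even for a final partial batch), so the loss gatherer may receive more entries than there are examples and is padded to a multiple of `world_size * batch_size`, while predictions and labels only need a multiple of `world_size`. A small sketch of that sizing rule with illustrative numbers, mirroring the rounding done in `DistributedTensorGatherer.__init__`:

```python
import numpy as np

def padded_size(num_samples, world_size, make_multiple_of=None):
    # Same rounding as DistributedTensorGatherer.__init__
    total = world_size if make_multiple_of is None else world_size * make_multiple_of
    return int(np.ceil(num_samples / total)) * total

world_size, num_examples, batch_size = 4, 21, 8  # illustrative numbers

print(padded_size(num_examples, world_size, make_multiple_of=batch_size))  # 32 slots for the repeated losses
print(padded_size(num_examples, world_size))                               # 24 slots for predictions / labels
```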
@@ -1288,55 +1300,46 @@ class Trainer:

        self.callback_handler.eval_dataloader = dataloader

        for inputs in dataloader:
        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds = logits if preds is None else nested_concat(preds, logits, dim=0)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = nested_xla_mesh_reduce(preds, "eval_preds")
            if label_ids is not None:
                label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
        labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = nested_numpify(preds)
        if label_ids is not None:
            label_ids = nested_numpify(label_ids)
        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize()
        label_ids = labels_gatherer.finalize()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (
                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
                    .mean()
                    .item()
                )
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        if eval_loss is not None:
            metrics["eval_loss"] = eval_loss.mean().item()

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
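Stripped of the Trainer plumbing, the accumulation pattern added above boils down to: keep concatenating per-step outputs on the device, and every `eval_accumulation_steps` steps hand them to a CPU-side gatherer and reset. A condensed sketch under those assumptions (the callables are stand-ins for `prediction_step`, `_gather_and_numpify` and the gatherers, not real Trainer internals):

```python
import torch

def prediction_loop_sketch(dataloader, predict, to_cpu_numpy, gatherer, eval_accumulation_steps=None):
    preds_host = None
    for step, batch in enumerate(dataloader):
        logits = predict(batch)
        preds_host = logits if preds_host is None else torch.cat((preds_host, logits), dim=0)

        # Periodically move what has accumulated on device into the CPU gatherer.
        if eval_accumulation_steps is not None and (step + 1) % eval_accumulation_steps == 0:
            gatherer.add_arrays(to_cpu_numpy(preds_host))
            preds_host = None  # start a fresh accumulation

    # Flush whatever is left after the last step, then reassemble in dataset order.
    if preds_host is not None:
        gatherer.add_arrays(to_cpu_numpy(preds_host))
    return gatherer.finalize()
```

When `eval_accumulation_steps` is left unset, the inner flush never triggers and everything stays on the device until the end, which is the old (faster but more memory-hungry) behaviour.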
@@ -1345,6 +1348,20 @@ class Trainer:

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:

@@ -1374,8 +1391,7 @@ class Trainer:
        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                # The .mean() is to reduce in case of distributed training
                loss = outputs[0].mean().item()
                loss = outputs[0].mean().detach()
                logits = outputs[1:]
            else:
                loss = None
@@ -21,11 +21,13 @@ import warnings
from contextlib import contextmanager
from typing import List, Optional, Union

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler

from .file_utils import is_torch_tpu_available
from .utils import logging


if is_torch_tpu_available():

@@ -33,6 +35,8 @@ if is_torch_tpu_available():

PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."

logger = logging.get_logger(__name__)


def nested_concat(tensors, new_tensors, dim=0):
    "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."

@@ -41,7 +45,12 @@ def nested_concat(tensors, new_tensors, dim=0):
    ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
    return torch.cat((tensors, new_tensors), dim=dim)
    elif isinstance(tensors, torch.Tensor):
        return torch.cat((tensors, new_tensors), dim=dim)
    elif isinstance(tensors, np.ndarray):
        return np.concatenate((tensors, new_tensors), axis=dim)
    else:
        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")


def nested_numpify(tensors):
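As a quick illustration of what the extended `nested_concat` accepts after this change (torch tensors, numpy arrays, and nested lists/tuples of either); this is a standalone check, assuming a version of the library containing the change:

```python
import numpy as np
import torch
from transformers.trainer_pt_utils import nested_concat

a = (torch.ones(2, 3), [np.zeros((2, 4)), np.ones((2, 1))])
b = (torch.zeros(1, 3), [np.ones((1, 4)), np.zeros((1, 1))])

merged = nested_concat(a, b, dim=0)  # same nesting, concatenated on the first dimension
print(merged[0].shape)     # torch.Size([3, 3])
print(merged[1][0].shape)  # (3, 4)
```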
@@ -177,3 +186,112 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def nested_new_like(arrays, num_samples):
    """ Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
    if isinstance(arrays, (list, tuple)):
        return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
    return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype)


def nested_truncate(tensors, limit):
    "Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_truncate(t, limit) for t in tensors)
    return tensors[:limit]


class DistributedTensorGatherer:
    """
    A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU
    by chunks.

    If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on
    CPU at every step, our sampler will generate the following indices:

        :obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`

    to get something of size a multiple of 3 (so that each process gets the same dataset length). Then
    process 0, 1 and 2 will be responsible for making predictions for the following samples:

        - P0: :obj:`[0, 1, 2, 3, 4, 5]`
        - P1: :obj:`[6, 7, 8, 9, 10, 11]`
        - P2: :obj:`[12, 13, 14, 15, 0, 1]`

    The first batch treated on each process will be

        - P0: :obj:`[0, 1]`
        - P1: :obj:`[6, 7]`
        - P2: :obj:`[12, 13]`

    So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor)
    corresponding to the following indices:

        :obj:`[0, 1, 6, 7, 12, 13]`

    If we directly concatenate our results without taking any precautions, the user will then get
    the predictions for the indices in this order at the end of the prediction loop:

        :obj:`[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`

    For some reason, that's not going to roll their boat. This class is there to solve that problem.

    Args:

        world_size (:obj:`int`):
            The number of processes used in the distributed training.
        num_samples (:obj:`int`):
            The number of samples in our dataset.
        make_multiple_of (:obj:`int`, `optional`):
            If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
            (by adding samples).
    """

    def __init__(self, world_size, num_samples, make_multiple_of=None):
        self.world_size = world_size
        self.num_samples = num_samples
        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
        self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
        self.process_length = self.total_samples // world_size
        self._storage = None
        self._offsets = None

    def add_arrays(self, arrays):
        """
        Add :obj:`arrays` to the internal storage. Will initialize the storage to the full size at the first arrays
        passed so that if we're bound to get an OOM, it happens at the beginning.
        """
        if arrays is None:
            return
        if self._storage is None:
            self._storage = nested_new_like(arrays, self.total_samples)
            self._offsets = list(range(0, self.total_samples, self.process_length))
        slice_len = self._nested_set_tensors(self._storage, arrays)
        for i in range(self.world_size):
            self._offsets[i] += slice_len

    def _nested_set_tensors(self, storage, arrays):
        if isinstance(arrays, (list, tuple)):
            for x, y in zip(storage, arrays):
                slice_len = self._nested_set_tensors(x, y)
            return slice_len
        assert (
            arrays.shape[0] % self.world_size == 0
        ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."

        slice_len = arrays.shape[0] // self.world_size
        for i in range(self.world_size):
            storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
        return slice_len

    def finalize(self):
        """
        Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
        to get each process a dataset of the same length).
        """
        if self._storage is None:
            return
        if self._offsets[0] != self.process_length:
            logger.warn("Not all data has been set. Are you sure you passed all values?")
        return nested_truncate(self._storage, self.num_samples)
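To make the docstring's scenario concrete, here is a small standalone illustration of how the gatherer reassembles per-process chunks into dataset order (16 samples, 3 processes, per-process batch size 2, so the sampler pads to 18 and wraps around to indices 0 and 1); it mirrors the unit test added at the end of this commit:

```python
import numpy as np
from transformers.trainer_pt_utils import DistributedTensorGatherer

predictions = np.arange(16).reshape(16, 1)
gatherer = DistributedTensorGatherer(world_size=3, num_samples=16)

# Each gather step hands over one chunk per process, concatenated along the first dimension.
for indices in ([0, 1, 6, 7, 12, 13], [2, 3, 8, 9, 14, 15], [4, 5, 10, 11, 0, 1]):
    gatherer.add_arrays(predictions[list(indices)])

result = gatherer.finalize()  # truncated back to 16 samples, in the original order
assert np.array_equal(result, predictions)
```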
@@ -67,7 +67,7 @@ class TrainingArguments:
            The batch size per GPU/TPU core/CPU for training.
        per_device_eval_batch_size (:obj:`int`, `optional`, defaults to 8):
            The batch size per GPU/TPU core/CPU for evaluation.
        gradient_accumulation_steps: (:obj:`int`, `optional`, defaults to 1):
        gradient_accumulation_steps (:obj:`int`, `optional`, defaults to 1):
            Number of updates steps to accumulate the gradients for, before performing a backward/update pass.

            .. warning::

@@ -75,6 +75,10 @@ class TrainingArguments:
                When using gradient accumulation, one step is counted as one step with backward pass. Therefore,
                logging, evaluation, save will be conducted every ``gradient_accumulation_steps * xxx_step`` training
                examples.
        eval_accumulation_steps (:obj:`int`, `optional`):
            Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
            left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster but
            requires more memory).
        learning_rate (:obj:`float`, `optional`, defaults to 5e-5):
            The initial learning rate for Adam.
        weight_decay (:obj:`float`, `optional`, defaults to 0):

@@ -225,6 +229,10 @@ class TrainingArguments:
        default=1,
        metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
    )
    eval_accumulation_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
    )

    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})
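A minimal usage sketch of the new option (output path and sizes are placeholders):

```python
from transformers import TrainingArguments

# Move the accumulated prediction tensors to the CPU every 20 evaluation steps
# instead of keeping the full set of predictions on the GPU/TPU.
args = TrainingArguments(
    output_dir="./out",
    per_device_eval_batch_size=8,
    eval_accumulation_steps=20,
)
```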
@@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2018 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import os
import tempfile
@@ -13,15 +13,14 @@
# CUDA_VISIBLE_DEVICES=-1 python ./tests/test_trainer_distributed.py
#


import logging
import sys
from typing import Dict

from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.utils import logging


logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)


if is_torch_available():

@@ -101,4 +100,20 @@ if __name__ == "__main__":
            logger.error(p.metrics)
            exit(1)

        trainer.args.eval_accumulation_steps = 2

        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            exit(1)

        trainer.args.eval_accumulation_steps = None

    logger.info("🔥 All distributed tests successful")
@@ -0,0 +1,119 @@
# This test is meant to be run on an instance with TPUs like this:
#
# python examples/xla_spawn.py --num_cores=8 tests/test_trainer_tpu.py
#
# Replace 8 with the number of TPU cores you have.
#

import sys
from typing import Dict

from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.utils import logging


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch
    from torch import nn
    from torch.utils.data.dataset import Dataset

    from transformers import Trainer

    class DummyDataset(Dataset):
        def __init__(self, length: int = 101):
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, i) -> int:
            return i

    class DummyDataCollator:
        def __call__(self, features):
            return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}

    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Add some (unused) params otherwise DDP will complain.
            self.fc = nn.Linear(120, 80)

        def forward(self, input_ids, labels=None):
            if labels is not None:
                return torch.tensor(0.0, device=input_ids.device), input_ids
            else:
                return input_ids


def main():
    parser = HfArgumentParser((TrainingArguments,))
    sys.argv += ["--output_dir", "./examples"]
    training_args = parser.parse_args_into_dataclasses()[0]

    logger.warning(
        "Process rank: %s, device: %s, tpu_num_cores: %s",
        training_args.local_rank,
        training_args.device,
        training_args.tpu_num_cores,
    )

    # Essentially, what we want to verify in the distributed case is
    # that we get all samples back, in the right order.
    # (this is crucial for prediction for instance)
    for dataset_length in [1001, 256, 15]:
        dataset = DummyDataset(dataset_length)

        def compute_metrics(p: EvalPrediction) -> Dict:
            sequential = list(range(len(dataset)))
            success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
            return {"success": success}

        trainer = Trainer(
            model=DummyModel(),
            args=training_args,
            data_collator=DummyDataCollator(),
            eval_dataset=dataset,
            compute_metrics=compute_metrics,
        )
        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            exit(1)

        trainer.args.eval_accumulation_steps = 2

        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            exit(1)

        trainer.args.eval_accumulation_steps = None

    logger.info("🔥 All distributed tests successful")


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
@@ -0,0 +1,58 @@
# coding=utf-8
# Copyright 2018 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers.file_utils import is_torch_available
from transformers.testing_utils import require_torch


if is_torch_available():
    from transformers.trainer_pt_utils import DistributedTensorGatherer


@require_torch
class TrainerUtilsTest(unittest.TestCase):
    def test_distributed_tensor_gatherer(self):
        # Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
        world_size = 4
        num_samples = 21
        input_indices = [
            [0, 1, 6, 7, 12, 13, 18, 19],
            [2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
            [5, 11, 17, 2],
        ]

        predictions = np.random.normal(size=(num_samples, 13))
        gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
        for indices in input_indices:
            gatherer.add_arrays(predictions[indices])
        result = gatherer.finalize()
        self.assertTrue(np.array_equal(result, predictions))

        # With nested tensors
        gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
        for indices in input_indices:
            gatherer.add_arrays([predictions[indices], [predictions[indices], predictions[indices]]])
        result = gatherer.finalize()
        self.assertTrue(isinstance(result, list))
        self.assertTrue(len(result), 2)
        self.assertTrue(isinstance(result[1], list))
        self.assertTrue(len(result[1]), 2)
        self.assertTrue(np.array_equal(result[0], predictions))
        self.assertTrue(np.array_equal(result[1][0], predictions))
        self.assertTrue(np.array_equal(result[1][1], predictions))