Add predict step accumulation (#7767)

* Add eval_accumulation_step and clean distributed eval

* Add TPU test

* Add TPU stuff

* Fix arg name

* Fix Seq2SeqTrainer

* Fix total_size

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Doc and add test to TPU

* Add unit test

* Adapt name

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger 2020-10-14 11:41:45 -04:00 committed by GitHub
parent 8feb0cc967
commit a1d1b332d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 413 additions and 47 deletions

View File

@ -19,3 +19,9 @@ Callbacks internals
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.trainer_callback.CallbackHandler
Distributed Evaluation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.trainer_pt_utils.DistributedTensorGatherer
:members:

View File

@ -174,7 +174,7 @@ class Seq2SeqTrainer(Trainer):
# Call forward again to get loss # TODO: avoidable?
outputs = model(**inputs, use_cache=False)
loss = self._compute_loss(outputs[1], labels_out)
loss = loss.mean().item()
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)

View File

@ -81,3 +81,14 @@ class TorchXLAExamplesTests(unittest.TestCase):
# Assert that the script takes less than 300 seconds to make sure it doesn't hang.
self.assertLess(end - start, 300)
def test_trainer_tpu(self):
import xla_spawn
testargs = """
transformers/tests/test_trainer_tpu.py
--num_cores=8
transformers/tests/test_trainer_tpu.py
""".split()
with patch.object(sys, "argv", testargs):
xla_spawn.main()

View File

@ -59,6 +59,7 @@ from .trainer_callback import (
TrainerState,
)
from .trainer_pt_utils import (
DistributedTensorGatherer,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
@ -1266,18 +1267,29 @@ class Trainer:
# multi-gpu eval
if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model)
else:
model = self.model
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
batch_size = dataloader.batch_size
num_examples = self.num_examples(dataloader)
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", self.num_examples(dataloader))
logger.info(" Num examples = %d", num_examples)
logger.info(" Batch size = %d", batch_size)
eval_losses: List[float] = []
preds: torch.Tensor = None
label_ids: torch.Tensor = None
losses_host: torch.Tensor = None
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
world_size = 1
if is_torch_tpu_available():
world_size = xm.xrt_world_size()
elif self.args.local_rank != -1:
world_size = torch.distributed.get_world_size()
world_size = max(1, world_size)
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
model.eval()
if is_torch_tpu_available():
@ -1288,55 +1300,46 @@ class Trainer:
self.callback_handler.eval_dataloader = dataloader
for inputs in dataloader:
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
if loss is not None:
eval_losses.extend([loss] * batch_size)
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if logits is not None:
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
if labels is not None:
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
# Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
preds = nested_xla_mesh_reduce(preds, "eval_preds")
if label_ids is not None:
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
# Gather all remaining tensors and put them back on the CPU
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None:
preds = nested_numpify(preds)
if label_ids is not None:
label_ids = nested_numpify(label_ids)
eval_loss = eval_losses_gatherer.finalize()
preds = preds_gatherer.finalize()
label_ids = labels_gatherer.finalize()
if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
else:
metrics = {}
if len(eval_losses) > 0:
if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)
if eval_loss is not None:
metrics["eval_loss"] = eval_loss.mean().item()
# Prefix all keys with eval_
for key in list(metrics.keys()):
@ -1345,6 +1348,20 @@ class Trainer:
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def _gather_and_numpify(self, tensors, name):
"""
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
concatenating them to `gathered`
"""
if tensors is None:
return
if is_torch_tpu_available():
tensors = nested_xla_mesh_reduce(tensors, name)
elif self.args.local_rank != -1:
tensors = distributed_concat(tensors)
return nested_numpify(tensors)
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
@ -1374,8 +1391,7 @@ class Trainer:
with torch.no_grad():
outputs = model(**inputs)
if has_labels:
# The .mean() is to reduce in case of distributed training
loss = outputs[0].mean().item()
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None

View File

@ -21,11 +21,13 @@ import warnings
from contextlib import contextmanager
from typing import List, Optional, Union
import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_torch_tpu_available
from .utils import logging
if is_torch_tpu_available():
@ -33,6 +35,8 @@ if is_torch_tpu_available():
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
logger = logging.get_logger(__name__)
def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
@ -41,7 +45,12 @@ def nested_concat(tensors, new_tensors, dim=0):
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
elif isinstance(tensors, torch.Tensor):
return torch.cat((tensors, new_tensors), dim=dim)
elif isinstance(tensors, np.ndarray):
return np.concatenate((tensors, new_tensors), axis=dim)
else:
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
def nested_numpify(tensors):
@ -177,3 +186,112 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
def nested_new_like(arrays, num_samples):
""" Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
if isinstance(arrays, (list, tuple)):
return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype)
def nested_truncate(tensors, limit):
"Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_truncate(t, limit) for t in tensors)
return tensors[:limit]
class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU
by chunks.
If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on
CPU at every step, our sampler will generate the following indices:
:obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`
to get something of size a multiple of 3 (so that each process gets the same dataset length). Then
process 0, 1 and 2 will be responsible of making predictions for the following samples:
- P0: :obj:`[0, 1, 2, 3, 4, 5]`
- P1: :obj:`[6, 7, 8, 9, 10, 11]`
- P2: :obj:`[12, 13, 14, 15, 0, 1]`
The first batch treated on each process will be
- P0: :obj:`[0, 1]`
- P1: :obj:`[6, 7]`
- P2: :obj:`[12, 13]`
So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor)
corresponding to the following indices:
:obj:`[0, 1, 6, 7, 12, 13]`
If we directly concatenate our results without taking any precautions, the user will then get
the predictions for the indices in this order at the end of the prediction loop:
:obj:`[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`
For some reason, that's not going to roll their boat. This class is there to solve that problem.
Args:
world_size (:obj:`int`):
The number of processes used in the distributed training.
num_samples (:obj:`int`):
The number of samples in our dataset.
make_multiple_of (:obj:`int`, `optional`):
If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
(by adding samples).
"""
def __init__(self, world_size, num_samples, make_multiple_of=None):
self.world_size = world_size
self.num_samples = num_samples
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
self.process_length = self.total_samples // world_size
self._storage = None
self._offsets = None
def add_arrays(self, arrays):
"""
Add :obj:`arrays` to the internal storage, Will initialize the storage to the full size at the first arrays
passed so that if we're bound to get an OOM, it happens at the beginning.
"""
if arrays is None:
return
if self._storage is None:
self._storage = nested_new_like(arrays, self.total_samples)
self._offsets = list(range(0, self.total_samples, self.process_length))
slice_len = self._nested_set_tensors(self._storage, arrays)
for i in range(self.world_size):
self._offsets[i] += slice_len
def _nested_set_tensors(self, storage, arrays):
if isinstance(arrays, (list, tuple)):
for x, y in zip(storage, arrays):
slice_len = self._nested_set_tensors(x, y)
return slice_len
assert (
arrays.shape[0] % self.world_size == 0
), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."
slice_len = arrays.shape[0] // self.world_size
for i in range(self.world_size):
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
return slice_len
def finalize(self):
"""
Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
to get each process a dataset of the same length).
"""
if self._storage is None:
return
if self._offsets[0] != self.process_length:
logger.warn("Not all data has been set. Are you sure you passed all values?")
return nested_truncate(self._storage, self.num_samples)

View File

@ -67,7 +67,7 @@ class TrainingArguments:
The batch size per GPU/TPU core/CPU for training.
per_device_eval_batch_size (:obj:`int`, `optional`, defaults to 8):
The batch size per GPU/TPU core/CPU for evaluation.
gradient_accumulation_steps: (:obj:`int`, `optional`, defaults to 1):
gradient_accumulation_steps (:obj:`int`, `optional`, defaults to 1):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
.. warning::
@ -75,6 +75,10 @@ class TrainingArguments:
When using gradient accumulation, one step is counted as one step with backward pass. Therefore,
logging, evaluation, save will be conducted every ``gradient_accumulation_steps * xxx_step`` training
examples.
eval_accumulation_steps (:obj:`int`, `optional`):
Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster but
requires more memory).
learning_rate (:obj:`float`, `optional`, defaults to 5e-5):
The initial learning rate for Adam.
weight_decay (:obj:`float`, `optional`, defaults to 0):
@ -225,6 +229,10 @@ class TrainingArguments:
default=1,
metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
)
eval_accumulation_steps: Optional[int] = field(
default=None,
metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
)
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2018 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import os
import tempfile

View File

@ -13,15 +13,14 @@
# CUDA_VISIBLE_DEVICES=-1 python ./tests/test_trainer_distributed.py
#
import logging
import sys
from typing import Dict
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
if is_torch_available():
@ -101,4 +100,20 @@ if __name__ == "__main__":
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = 2
metrics = trainer.evaluate()
logger.info(metrics)
if metrics["eval_success"] is not True:
logger.error(metrics)
exit(1)
p = trainer.predict(dataset)
logger.info(p.metrics)
if p.metrics["eval_success"] is not True:
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = None
logger.info("🔥 All distributed tests successful")

119
tests/test_trainer_tpu.py Normal file
View File

@ -0,0 +1,119 @@
# This test is meant to be run in on an instance with TPUs like this:
#
# python examples/xla_spawn.py --num_cores=8 tests/test_trainer_tpu.py
#
# Replace 8 with the number of TPU cores you have.
#
import sys
from typing import Dict
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.utils import logging
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from transformers import Trainer
class DummyDataset(Dataset):
def __init__(self, length: int = 101):
self.length = length
def __len__(self):
return self.length
def __getitem__(self, i) -> int:
return i
class DummyDataCollator:
def __call__(self, features):
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
# Add some (unused) params otherwise DDP will complain.
self.fc = nn.Linear(120, 80)
def forward(self, input_ids, labels=None):
if labels is not None:
return torch.tensor(0.0, device=input_ids.device), input_ids
else:
return input_ids
def main():
parser = HfArgumentParser((TrainingArguments,))
sys.argv += ["--output_dir", "./examples"]
training_args = parser.parse_args_into_dataclasses()[0]
logger.warning(
"Process rank: %s, device: %s, tpu_num_cores: %s",
training_args.local_rank,
training_args.device,
training_args.tpu_num_cores,
)
# Essentially, what we want to verify in the distributed case is
# that we get all samples back, in the right order.
# (this is crucial for prediction for instance)
for dataset_length in [1001, 256, 15]:
dataset = DummyDataset(dataset_length)
def compute_metrics(p: EvalPrediction) -> Dict:
sequential = list(range(len(dataset)))
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
return {"success": success}
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
metrics = trainer.evaluate()
logger.info(metrics)
if metrics["eval_success"] is not True:
logger.error(metrics)
exit(1)
p = trainer.predict(dataset)
logger.info(p.metrics)
if p.metrics["eval_success"] is not True:
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = 2
metrics = trainer.evaluate()
logger.info(metrics)
if metrics["eval_success"] is not True:
logger.error(metrics)
exit(1)
p = trainer.predict(dataset)
logger.info(p.metrics)
if p.metrics["eval_success"] is not True:
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = None
logger.info("🔥 All distributed tests successful")
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,58 @@
# coding=utf-8
# Copyright 2018 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.file_utils import is_torch_available
from transformers.testing_utils import require_torch
if is_torch_available():
from transformers.trainer_pt_utils import DistributedTensorGatherer
@require_torch
class TrainerUtilsTest(unittest.TestCase):
def test_distributed_tensor_gatherer(self):
# Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
world_size = 4
num_samples = 21
input_indices = [
[0, 1, 6, 7, 12, 13, 18, 19],
[2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
[5, 11, 17, 2],
]
predictions = np.random.normal(size=(num_samples, 13))
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
for indices in input_indices:
gatherer.add_arrays(predictions[indices])
result = gatherer.finalize()
self.assertTrue(np.array_equal(result, predictions))
# With nested tensors
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
for indices in input_indices:
gatherer.add_arrays([predictions[indices], [predictions[indices], predictions[indices]]])
result = gatherer.finalize()
self.assertTrue(isinstance(result, list))
self.assertTrue(len(result), 2)
self.assertTrue(isinstance(result[1], list))
self.assertTrue(len(result[1]), 2)
self.assertTrue(np.array_equal(result[0], predictions))
self.assertTrue(np.array_equal(result[1][0], predictions))
self.assertTrue(np.array_equal(result[1][1], predictions))