Fix the pad-across-processes dim in Trainer and allow setting the DDP timeout (#24775)

* dim, and rm copy

* Don't rm copy for now

* Oops

* pad index

* Should be a working test

* Trickle down ddp timeout

* Put fix back in now that testing locally is done

* Better comment specifying timeout

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Zach Mueller 2023-07-12 10:01:51 -04:00 committed by GitHub
parent 4f85aaa6c9
commit 0284285501
3 changed files with 52 additions and 4 deletions

@@ -3131,9 +3131,9 @@ class Trainer:
                 losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
                 losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
             if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels)
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
             if inputs_decode is not None:
-                inputs_decode = self.accelerator.pad_across_processes(inputs_decode)
+                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                 inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
                 inputs_host = (
                     inputs_decode
@@ -3141,7 +3141,7 @@ class Trainer:
                     else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                 )
             if logits is not None:
-                logits = self.accelerator.pad_across_processes(logits)
+                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.accelerator.gather_for_metrics((logits))
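Why the new arguments matter: each process can produce label/logit tensors with a different sequence length, so they must be padded along the sequence dimension (dim=1) before being gathered, and pad_index=-100 matches the padding_index that nested_concat already uses. Below is a minimal single-process sketch of the idea, assuming PyTorch is available; it is not Accelerate's implementation, and pad_to_length is a hypothetical helper.

# Sketch only: emulate padding variable-length label tensors from two
# "processes" along dim=1 with -100 before concatenating, which is what
# pad_across_processes(..., dim=1, pad_index=-100) enables across real processes.
import torch
import torch.nn.functional as F


def pad_to_length(tensor: torch.Tensor, max_len: int, dim: int = 1, pad_index: int = -100) -> torch.Tensor:
    """Hypothetical helper: right-pad `tensor` along `dim` with `pad_index` up to `max_len`."""
    missing = max_len - tensor.size(dim)
    if missing <= 0:
        return tensor
    # F.pad expects (left, right) pairs starting from the last dimension.
    pads = [0] * (2 * tensor.dim())
    pads[2 * (tensor.dim() - 1 - dim) + 1] = missing
    return F.pad(tensor, pads, value=pad_index)


labels_rank0 = torch.tensor([[5, 6, 7]])       # seq_len 3 on "process 0"
labels_rank1 = torch.tensor([[8, 9, 10, 11]])  # seq_len 4 on "process 1"
longest = max(labels_rank0.size(1), labels_rank1.size(1))

gathered = torch.cat(
    [pad_to_length(labels_rank0, longest), pad_to_length(labels_rank1, longest)], dim=0
)
print(gathered)  # both rows now have equal length; the padded position holds -100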

@@ -1714,7 +1714,9 @@ class TrainingArguments:
             del os.environ["ACCELERATE_USE_DEEPSPEED"]
             self._n_gpu = 1
         else:
-            self.distributed_state = PartialState(backend=self.ddp_backend)
+            self.distributed_state = PartialState(
+                backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
+            )
             self._n_gpu = 1
         if not is_sagemaker_mp_enabled():
             device = self.distributed_state.device
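With this change, the `ddp_timeout` value (in seconds) set on `TrainingArguments` is actually forwarded to the `PartialState` that initializes the process group, so collectives use the requested timeout instead of the backend default. A hedged usage sketch follows; the `output_dir` value is only a placeholder.

from datetime import timedelta

from transformers import TrainingArguments

# Ask for a 30-minute collective timeout; before this fix the value was not
# passed through to PartialState in the plain-DDP branch.
args = TrainingArguments(output_dir="tmp_trainer_out", ddp_timeout=1800)

# Roughly what training_args.py now does internally:
#   PartialState(backend=args.ddp_backend, timeout=timedelta(seconds=args.ddp_timeout))
assert timedelta(seconds=args.ddp_timeout) == timedelta(minutes=30)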

@@ -49,6 +49,7 @@ from transformers.testing_utils import (
     USER,
     CaptureLogger,
     TestCasePlus,
+    execute_subprocess_async,
     get_gpu_count,
     get_tests_dir,
     is_staging_test,
@@ -2098,6 +2099,51 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
         self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
 
+    @slow
+    @require_torch_multi_gpu
+    def test_end_to_end_example(self):
+        # Tests that `translation.py` will run without issues
+        script_path = os.path.abspath(
+            os.path.join(
+                os.path.dirname(__file__), "..", "..", "examples", "pytorch", "translation", "run_translation.py"
+            )
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            command = [
+                "accelerate",
+                "launch",
+                script_path,
+                "--model_name_or_path",
+                "t5-small",
+                "--per_device_train_batch_size",
+                "1",
+                "--output_dir",
+                tmpdir,
+                "--overwrite_output_dir",
+                "--do_train",
+                "--max_train_samples",
+                "64",
+                "--num_train_epochs",
+                "1",
+                "--dataset_name",
+                "wmt16",
+                "--dataset_config",
+                "ro-en",
+                "--source_lang",
+                "en",
+                "--target_lang",
+                "ro",
+                "--do_predict",
+                "--max_predict_samples",
+                "64",
+                "--predict_with_generate",
+                "--ddp_timeout",
+                "60",
+            ]
+            execute_subprocess_async(command)
+            # successful return here == success - any errors would have caused an error or a timeout in the sub-call
+
 
 @require_torch
 @is_staging_test