Fix pad across processes dim in trainer and not being able to set the timeout (#24775)
* dim, and rm copy * Don't rm copy for now * Oops * pad index * Should be a working test * Trickle down ddp timeout * Put fix back in now that testing locally is done * Better comment specifying timeout Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
4f85aaa6c9
commit
0284285501
|
@ -3131,9 +3131,9 @@ class Trainer:
|
||||||
losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
|
losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
|
||||||
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
|
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = self.accelerator.pad_across_processes(labels)
|
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
||||||
if inputs_decode is not None:
|
if inputs_decode is not None:
|
||||||
inputs_decode = self.accelerator.pad_across_processes(inputs_decode)
|
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
|
||||||
inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
|
inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
|
||||||
inputs_host = (
|
inputs_host = (
|
||||||
inputs_decode
|
inputs_decode
|
||||||
|
@ -3141,7 +3141,7 @@ class Trainer:
|
||||||
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
|
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
|
||||||
)
|
)
|
||||||
if logits is not None:
|
if logits is not None:
|
||||||
logits = self.accelerator.pad_across_processes(logits)
|
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
|
||||||
if self.preprocess_logits_for_metrics is not None:
|
if self.preprocess_logits_for_metrics is not None:
|
||||||
logits = self.preprocess_logits_for_metrics(logits, labels)
|
logits = self.preprocess_logits_for_metrics(logits, labels)
|
||||||
logits = self.accelerator.gather_for_metrics((logits))
|
logits = self.accelerator.gather_for_metrics((logits))
|
||||||
|
|
|
@ -1714,7 +1714,9 @@ class TrainingArguments:
|
||||||
del os.environ["ACCELERATE_USE_DEEPSPEED"]
|
del os.environ["ACCELERATE_USE_DEEPSPEED"]
|
||||||
self._n_gpu = 1
|
self._n_gpu = 1
|
||||||
else:
|
else:
|
||||||
self.distributed_state = PartialState(backend=self.ddp_backend)
|
self.distributed_state = PartialState(
|
||||||
|
backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
|
||||||
|
)
|
||||||
self._n_gpu = 1
|
self._n_gpu = 1
|
||||||
if not is_sagemaker_mp_enabled():
|
if not is_sagemaker_mp_enabled():
|
||||||
device = self.distributed_state.device
|
device = self.distributed_state.device
|
||||||
|
|
|
@ -49,6 +49,7 @@ from transformers.testing_utils import (
|
||||||
USER,
|
USER,
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
|
execute_subprocess_async,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
get_tests_dir,
|
get_tests_dir,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
|
@ -2098,6 +2099,51 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||||
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
|
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
|
||||||
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_end_to_end_example(self):
|
||||||
|
# Tests that `translation.py` will run without issues
|
||||||
|
script_path = os.path.abspath(
|
||||||
|
os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "..", "examples", "pytorch", "translation", "run_translation.py"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
command = [
|
||||||
|
"accelerate",
|
||||||
|
"launch",
|
||||||
|
script_path,
|
||||||
|
"--model_name_or_path",
|
||||||
|
"t5-small",
|
||||||
|
"--per_device_train_batch_size",
|
||||||
|
"1",
|
||||||
|
"--output_dir",
|
||||||
|
tmpdir,
|
||||||
|
"--overwrite_output_dir",
|
||||||
|
"--do_train",
|
||||||
|
"--max_train_samples",
|
||||||
|
"64",
|
||||||
|
"--num_train_epochs",
|
||||||
|
"1",
|
||||||
|
"--dataset_name",
|
||||||
|
"wmt16",
|
||||||
|
"--dataset_config",
|
||||||
|
"ro-en",
|
||||||
|
"--source_lang",
|
||||||
|
"en",
|
||||||
|
"--target_lang",
|
||||||
|
"ro",
|
||||||
|
"--do_predict",
|
||||||
|
"--max_predict_samples",
|
||||||
|
"64",
|
||||||
|
"--predict_with_generate",
|
||||||
|
"--ddp_timeout",
|
||||||
|
"60",
|
||||||
|
]
|
||||||
|
execute_subprocess_async(command)
|
||||||
|
# successful return here == success - any errors would have caused an error or a timeout in the sub-call
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|
Loading…
Reference in New Issue