From 3b19c0317b6909e2d7f11b5053895ac55250e7da Mon Sep 17 00:00:00 2001 From: arun99481 Date: Tue, 6 Sep 2022 17:06:37 +0530 Subject: [PATCH] updating gather function with gather_for_metrics in run_wav2vec2_pretraining (#18877) Co-authored-by: Arun Rajaram --- .../run_wav2vec2_pretraining_no_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py index a3db215d08..0de1776df5 100755 --- a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py +++ b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py @@ -596,7 +596,7 @@ def main(): # make sure that `num_losses` is summed for distributed training # and average gradients over losses of all devices if accelerator.state.num_processes > 1: - num_losses = accelerator.gather(num_losses).sum() + num_losses = accelerator.gather_for_metrics(num_losses).sum() gradient_multiplier = accelerator.state.num_processes / num_losses multiply_grads(model.module.parameters(), gradient_multiplier) else: @@ -647,10 +647,10 @@ def main(): outputs.diversity_loss.detach() if accelerator.state.num_processes > 1: - loss = accelerator.gather(loss).sum() - outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum() - outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum() - percent_masked = accelerator.gather(percent_masked).sum() + loss = accelerator.gather_for_metrics(loss).sum() + outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum() + outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum() + percent_masked = accelerator.gather_for_metrics(percent_masked).sum() train_logs = { "loss": (loss * args.gradient_accumulation_steps) / num_losses, @@ -713,7 +713,7 @@ def main(): # sum over devices in multi-processing if accelerator.num_processes > 1: - val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()} + val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()} val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}