Update gather function to gather_for_metrics in run_wav2vec2_pretraining (#18877)

Co-authored-by: Arun Rajaram <arunrajaram@Aruns-MacBook-Pro.local>
arun99481 2022-09-06 17:06:37 +05:30 committed by GitHub
parent 734b7e2a5a
commit 3b19c0317b
1 changed file with 6 additions and 6 deletions

@@ -596,7 +596,7 @@ def main():
                 # make sure that `num_losses` is summed for distributed training
                 # and average gradients over losses of all devices
                 if accelerator.state.num_processes > 1:
-                    num_losses = accelerator.gather(num_losses).sum()
+                    num_losses = accelerator.gather_for_metrics(num_losses).sum()
                     gradient_multiplier = accelerator.state.num_processes / num_losses
                     multiply_grads(model.module.parameters(), gradient_multiplier)
                 else:
@@ -647,10 +647,10 @@ def main():
                 outputs.diversity_loss.detach()

                 if accelerator.state.num_processes > 1:
-                    loss = accelerator.gather(loss).sum()
-                    outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum()
-                    outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum()
-                    percent_masked = accelerator.gather(percent_masked).sum()
+                    loss = accelerator.gather_for_metrics(loss).sum()
+                    outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum()
+                    outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum()
+                    percent_masked = accelerator.gather_for_metrics(percent_masked).sum()

                 train_logs = {
                     "loss": (loss * args.gradient_accumulation_steps) / num_losses,
@@ -713,7 +713,7 @@ def main():
             # sum over devices in multi-processing
             if accelerator.num_processes > 1:
-                val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()}
+                val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}

             val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}
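
Background on the change, as a minimal runnable sketch (the script name gather_demo.py and the toy dataset are illustrative, not part of this commit): Accelerator.gather concatenates tensors from all processes, including the samples Accelerate duplicates so that every process receives even batches, whereas Accelerator.gather_for_metrics additionally drops those duplicates, so summed losses and other metrics correspond to exactly the real samples.

# gather_demo.py -- illustrative sketch, assuming a two-process launch:
#   accelerate launch --num_processes 2 gather_demo.py
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

# 10 samples do not divide evenly across 2 processes with batch_size=4,
# so the prepared dataloader duplicates samples to fill the last batch.
dataset = torch.arange(10).float()
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4))

gathered, gathered_for_metrics = [], []
for batch in dataloader:
    # gather: concatenates the per-process tensors, duplicates included
    gathered.append(accelerator.gather(batch))
    # gather_for_metrics: also strips the duplicated samples from the
    # final batch, so totals match the true dataset size
    gathered_for_metrics.append(accelerator.gather_for_metrics(batch))

if accelerator.is_main_process:
    # gather may report more elements than the dataset has (duplicates)
    print("gather:", torch.cat(gathered).numel())
    # gather_for_metrics reports exactly the 10 real samples
    print("gather_for_metrics:", torch.cat(gathered_for_metrics).numel())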