Replace `gather` with `gather_for_metrics` in run_wav2vec2_pretraining (#18877)
Co-authored-by: Arun Rajaram <arunrajaram@Aruns-MacBook-Pro.local>
This commit is contained in:
parent
734b7e2a5a
commit
3b19c0317b
|
@@ -596,7 +596,7 @@ def main():
|
|||
# make sure that `num_losses` is summed for distributed training
|
||||
# and average gradients over losses of all devices
|
||||
if accelerator.state.num_processes > 1:
|
||||
num_losses = accelerator.gather(num_losses).sum()
|
||||
num_losses = accelerator.gather_for_metrics(num_losses).sum()
|
||||
gradient_multiplier = accelerator.state.num_processes / num_losses
|
||||
multiply_grads(model.module.parameters(), gradient_multiplier)
|
||||
else:
|
||||
|
@@ -647,10 +647,10 @@ def main():
|
|||
outputs.diversity_loss.detach()
|
||||
|
||||
if accelerator.state.num_processes > 1:
|
||||
loss = accelerator.gather(loss).sum()
|
||||
outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum()
|
||||
outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum()
|
||||
percent_masked = accelerator.gather(percent_masked).sum()
|
||||
loss = accelerator.gather_for_metrics(loss).sum()
|
||||
outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum()
|
||||
outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum()
|
||||
percent_masked = accelerator.gather_for_metrics(percent_masked).sum()
|
||||
|
||||
train_logs = {
|
||||
"loss": (loss * args.gradient_accumulation_steps) / num_losses,
|
||||
|
@@ -713,7 +713,7 @@ def main():
|
|||
|
||||
# sum over devices in multi-processing
|
||||
if accelerator.num_processes > 1:
|
||||
val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()}
|
||||
val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}
|
||||
|
||||
val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}
|
||||
|
||||
|
|
Loading…
Reference in New Issue