[debug] DebugUnderflowOverflow doesn't work with DP (#12816)

Stas Bekman 2021-07-21 09:36:02 -07:00 committed by GitHub
parent ac3cb660ca
commit cf0755aa6e
3 changed files with 15 additions and 4 deletions


@@ -24,7 +24,11 @@ Underflow and Overflow Detection

 .. note::

-   This feature can be used with any ``nn.Module``-based model
+   For multi-GPU training it requires DDP (``torch.distributed.launch``).
+
+.. note::
+
+   This feature can be used with any ``nn.Module``-based model.

 If you start getting ``loss=NaN`` or the model exhibits some other abnormal behavior due to ``inf`` or ``nan`` in
 activations or weights one needs to discover where the first underflow or overflow happens and what led to it. Luckily
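For context on what the documented detector does when wired up by hand, here is a minimal sketch; the toy model is made up for illustration, and only the DebugUnderflowOverflow(model) call mirrors the one the Trainer makes in the next hunk:

# Hedged sketch (not part of this commit): attach the detector to any nn.Module-based
# model in a single-process run; it hooks every forward pass and reports the frames
# around the first inf/nan it detects.
import torch
from torch import nn
from transformers.debug_utils import DebugUnderflowOverflow

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
debug_overflow = DebugUnderflowOverflow(model)  # same call the Trainer makes below

out = model(torch.randn(4, 8))  # forward activations are now being watched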


@@ -1114,7 +1114,14 @@ class Trainer:
             num_train_samples = args.max_steps * total_train_batch_size

         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
-            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
+            if self.args.n_gpu > 1:
+                # nn.DataParallel(model) replicates the model, creating new variables and module
+                # references registered here no longer work on other gpus, breaking the module
+                raise ValueError(
+                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP (torch.distributed.launch)."
+                )
+            else:
+                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         if args.deepspeed:
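The comment in the new guard is terse, so here is a hedged illustration of the failure mode it points at; replicate() is used directly to show the effect without a full DataParallel forward pass, and the name lookup dict stands in for the detector's internal bookkeeping (an assumption about its internals, not code from this commit):

# Illustrative sketch only, assuming at least 2 CUDA devices: DataParallel runs
# forward() on per-device replicas produced by replicate(), and each replica is a
# brand-new module object. Any state keyed by the original modules therefore has
# no entry for the copy that actually ran on the other GPU.
import torch
from torch import nn
from torch.nn.parallel import replicate

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).to("cuda:0")
names_by_module = {m: name for name, m in model.named_modules()}

replicas = replicate(model, devices=[0, 1])
print(replicas[1][0] in names_by_module)  # False: the replica's Linear is a new object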


@@ -420,7 +420,7 @@ class TrainerMemoryTracker:
         self.cur_stage = None

     def update_metrics(self, stage, metrics):
-        """stop tracking for the passed stage"""
+        """updates the metrics"""
         if self.skip_memory_metrics:
             return

@@ -442,7 +442,7 @@ class TrainerMemoryTracker:
                 metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]

     def stop_and_update_metrics(self, metrics=None):
-        """combine stop + update in one call for simpler code"""
+        """combine stop and metrics update in one call for simpler code"""
         if self.skip_memory_metrics:
             return

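The two hunks above only touch docstrings, so for orientation here is a rough sketch of the calling convention they describe; MiniTrainer and the stage-by-caller-name detail are assumptions made for illustration, not part of this commit:

# Assumed usage pattern: the tracker derives the stage from the name of the calling
# method ("train" here), and stop_and_update_metrics() both stops tracking and writes
# the collected train_mem_cpu_*/train_mem_gpu_* deltas into the passed metrics dict.
from transformers.trainer_utils import TrainerMemoryTracker

class MiniTrainer:
    def __init__(self):
        self._memory_tracker = TrainerMemoryTracker()

    def train(self):
        self._memory_tracker.start()  # begin tracking the "train" stage
        metrics = {}
        # ... the actual training loop would run here ...
        self._memory_tracker.stop_and_update_metrics(metrics)
        return metrics

print(MiniTrainer().train())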