Fixed trainer total_flos reloading in distributed mode (#11383)

* Fixed trainer total_flos reloading in distributed mode

* logging flos at the end of training
Teven 2021-04-23 13:53:33 +02:00 committed by GitHub
parent 74e84f1fa6
commit 7bc86bea68
1 changed file with 13 additions and 13 deletions


@@ -426,9 +426,9 @@ class Trainer:
         self.state = TrainerState()
         self.control = TrainerControl()
-        # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
-        # state at each call to self.log.
-        self._total_flos = None
+        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+        # returned to 0 every time flos need to be logged
+        self.current_flos = 0
         self.hp_search_backend = None
         self.use_tune_checkpoints = False
         default_label_names = (
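
The new bookkeeping above splits the count in two: a per-process `current_flos` that only tracks work done since the last log, and the aggregated `self.state.total_flos` that survives checkpointing. A minimal standalone sketch of that pattern (not Trainer code; single process, made-up numbers):

```python
# Per-process counter folded into a running total at each log, then reset.
class FlosCounter:
    def __init__(self):
        self.total_flos = 0.0    # aggregated value, the one that gets checkpointed
        self.current_flos = 0.0  # local work done since the last log

    def on_step(self, step_flos: float):
        self.current_flos += step_flos

    def on_log(self) -> float:
        # The real Trainer sums current_flos across processes before this add.
        self.total_flos += self.current_flos
        self.current_flos = 0.0
        return self.total_flos

counter = FlosCounter()
for _ in range(10):
    counter.on_step(1e12)      # pretend every step costs 1e12 floating-point ops
print(counter.on_log())        # 1e13; the local counter is back at 0
```
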
@@ -1162,7 +1162,6 @@ class Trainer:
         # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
-        self._total_flos = self.state.total_flos
         model.zero_grad()

         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
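
The deleted line is the heart of the bug this PR fixes: when training resumed from a checkpoint, every process seeded its private counter with the already-aggregated `self.state.total_flos`, so the next cross-process sum in `store_flos` counted that total once per rank. A toy calculation with made-up numbers, assuming 4 processes:

```python
# Toy illustration of the double counting removed by this commit (4 ranks,
# 1e15 flos already aggregated in the reloaded TrainerState).
world_size = 4
resumed_total = 1e15

# Old behaviour: each rank's local counter started from the aggregated total,
# so summing the counters across ranks multiplied it by the world size.
old_next_log = resumed_total * world_size      # 4e15 instead of 1e15

# New behaviour: local counters start at 0, only fresh work is summed, and the
# aggregated total stays untouched in self.state.total_flos.
new_next_log = resumed_total + 0 * world_size  # still 1e15
print(old_next_log, new_next_log)
```
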
@@ -1220,7 +1219,7 @@
                         tr_loss += self.training_step(model, inputs)
                 else:
                     tr_loss += self.training_step(model, inputs)
-                self._total_flos += float(self.floating_point_ops(inputs))
+                self.current_flos += float(self.floating_point_ops(inputs))

                 # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                 if self.deepspeed:
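
For reference, the per-step cost added to `current_flos` comes from the model's `floating_point_ops(inputs)`. For a `PreTrainedModel` the default estimate is, if memory serves, roughly 6 floating-point operations per (non-embedding) parameter per input token, covering forward and backward; the sketch below reproduces that heuristic with made-up sizes, and models can override the real method:

```python
# Hedged sketch of a floating_point_ops-style heuristic (not the library code).
def estimate_step_flos(tokens_in_batch: int, n_params: int) -> float:
    # ~6 flos per parameter per token for a forward + backward pass
    return 6.0 * tokens_in_batch * n_params

# e.g. a 124M-parameter model on a batch of 8 sequences of 512 tokens
print(estimate_step_flos(8 * 512, 124_000_000))  # ~3.0e12 flos for the step
```
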
@@ -1321,9 +1320,8 @@
         )
         metrics = speed_metrics("train", start_time, self.state.max_steps)
-        if self._total_flos is not None:
-            self.store_flos()
-            metrics["total_flos"] = self.state.total_flos
+        self.store_flos()
+        metrics["total_flos"] = self.state.total_flos
         self.log(metrics)

         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
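
With the `if self._total_flos is not None` guard gone, the aggregated count is always folded in and reported at the end of training, next to the `speed_metrics` entries. After a run, the same number can be read back from the serialized `TrainerState`; a small sketch, where the checkpoint path is made up:

```python
# Reading the aggregated flos back from a saved trainer state (hypothetical path).
import json

with open("output/checkpoint-500/trainer_state.json") as f:
    state = json.load(f)

print(state["total_flos"])  # one number, already summed across processes
```

This is the same value the hunk above exposes as `metrics["total_flos"]` in the metrics returned by `train()`.
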
@@ -1788,11 +1786,12 @@ class Trainer:
     def store_flos(self):
         # Storing the number of floating-point operations that went into the model
-        if self._total_flos is not None:
-            if self.args.local_rank != -1:
-                self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
-            else:
-                self.state.total_flos = self._total_flos
+        if self.args.local_rank != -1:
+            self.state.total_flos += distributed_broadcast_scalars([self.current_flos]).sum().item()
+            self.current_flos = 0
+        else:
+            self.state.total_flos = self.current_flos
+            self.current_flos = 0

     def _sorted_checkpoints(
         self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
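
The rewritten `store_flos` uses the `distributed_broadcast_scalars` helper (in `trainer_pt_utils`) to gather every process's `current_flos`, sums the results, and adds that to `state.total_flos`, so each rank ends up with the same accumulated increment (note the switch from `=` to `+=` in the distributed branch). A rough `torch.distributed` equivalent, assuming a process group has already been initialized by the launcher:

```python
# Rough equivalent of the accumulation done in store_flos() above
# (sketch; assumes dist.init_process_group was called elsewhere).
import torch
import torch.distributed as dist

def accumulate_flos(total_flos: float, current_flos: float) -> float:
    """Add this process's flos since the last log to the shared running total."""
    t = torch.tensor(current_flos, dtype=torch.float64)
    if dist.is_available() and dist.is_initialized():
        # Sum the per-process counters so every rank adds the same increment,
        # mirroring distributed_broadcast_scalars([...]).sum() in the diff.
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return total_flos + t.item()
```
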
@@ -1883,6 +1882,7 @@
)
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
self.log(output.metrics)
if self.args.tpu_metrics_debug or self.args.debug:
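
Finally, the evaluation path shown in the last hunk merges `speed_metrics` entries into `output.metrics` before logging them. A minimal sketch of what such a helper computes, assuming the real one keys its entries as `<prefix>_runtime` and `<prefix>_samples_per_second` (exact keys may differ):

```python
# Sketch of a speed_metrics-style helper (not the transformers implementation).
import time

def speed_metrics_sketch(prefix: str, start_time: float, num_samples: int) -> dict:
    runtime = time.time() - start_time
    return {
        f"{prefix}_runtime": round(runtime, 4),
        f"{prefix}_samples_per_second": round(num_samples / runtime, 3),
    }

start = time.time()
time.sleep(0.1)                                   # stand-in for an evaluation loop
print(speed_metrics_sketch("eval", start, 320))   # e.g. {'eval_runtime': 0.1, ...}
```
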