Fix a few countings (steps / epochs) in trainer_tf.py (#7175)

Yih-Dar 2020-09-18 15:28:56 +02:00 committed by GitHub
parent ee9eae4e06
commit 3a03bab9db
1 changed file with 53 additions and 25 deletions


@@ -478,43 +478,58 @@ class TFTrainer:
         self.gradient_accumulator.reset()
 
+        num_update_steps_per_epoch = self.num_train_examples / self.total_train_batch_size
+
+        # In fact, ``self.args.dataloader_drop_last`` has no effect in `trainer_tf.py`, because
+        # the dataset is repeated before being batched.
+        # It has the effect only when TPU is used which requires explicit tensor shape in order to make
+        # the gradient accumulation implementation work.
+        approx = math.floor if self.args.dataloader_drop_last else math.ceil
+        num_update_steps_per_epoch = approx(num_update_steps_per_epoch)
+
+        # At least one update for each epoch.
+        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+        self.steps_per_epoch = num_update_steps_per_epoch
+
         if self.args.max_steps > 0:
             t_total = self.args.max_steps
-            self.steps_per_epoch = self.args.max_steps
+            epochs = (self.args.max_steps // self.steps_per_epoch) + int(
+                self.args.max_steps % self.steps_per_epoch > 0
+            )
         else:
-            approx = math.floor if self.args.dataloader_drop_last else math.ceil
-            self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
             t_total = self.steps_per_epoch * self.args.num_train_epochs
+            epochs = self.args.num_train_epochs
+
+        # Since ``self.args.num_train_epochs`` can be `float`, we make ``epochs`` be a `float` always.
+        epochs = float(epochs)
 
         with self.args.strategy.scope():
             self.create_optimizer_and_scheduler(num_training_steps=t_total)
-            iterations = self.optimizer.iterations
-            self.global_step = iterations.numpy()
             folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
             self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
 
+            iterations = self.optimizer.iterations
+            epochs_trained = 0
+            steps_trained_in_current_epoch = 0
             if self.model.ckpt_manager.latest_checkpoint:
-                epochs_trained = self.global_step // (self.num_train_examples // self.args.gradient_accumulation_steps)
-                steps_trained_in_current_epoch = self.global_step % (
-                    self.num_train_examples // self.args.gradient_accumulation_steps
-                )
-
+                logger.info(
+                    "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
+                )
+                ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
+
+                self.global_step = iterations.numpy()
+
+                epochs_trained = self.global_step // self.steps_per_epoch
+                steps_trained_in_current_epoch = self.global_step % self.steps_per_epoch
+
                 logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                 logger.info("  Continuing training from epoch %d", epochs_trained)
                 logger.info("  Continuing training from global step %d", self.global_step)
                 logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
 
-                logger.info(
-                    "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
-                )
-                ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
-            else:
-                epochs_trained = 1
-
-            tf.summary.experimental.set_step(iterations)
-
-            epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
+            tf.summary.experimental.set_step(self.global_step)
 
             if self.args.fp16:
                 policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
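For readers following the counting change in this hunk, here is a small self-contained sketch of the new logic with made-up numbers (1000 examples, total train batch size 32, max_steps 70). It mirrors the added lines above but is not itself part of the diff.

# Illustrative sketch of the new step/epoch counting (invented numbers, not
# part of trainer_tf.py): 1000 examples, total train batch size 32, max_steps 70.
import math

num_train_examples = 1000
total_train_batch_size = 32
dataloader_drop_last = False
max_steps = 70          # <= 0 would mean "train for num_train_epochs instead"
num_train_epochs = 3.0

num_update_steps_per_epoch = num_train_examples / total_train_batch_size
approx = math.floor if dataloader_drop_last else math.ceil
# At least one update per epoch, as in the added lines above.
steps_per_epoch = max(approx(num_update_steps_per_epoch), 1)             # 32

if max_steps > 0:
    t_total = max_steps                                                  # 70
    # Round the epoch count up so a final partial epoch is still counted.
    epochs = (max_steps // steps_per_epoch) + int(max_steps % steps_per_epoch > 0)  # 3
else:
    t_total = steps_per_epoch * num_train_epochs
    epochs = num_train_epochs

epochs = float(epochs)  # num_train_epochs may be a float, so keep the type uniform
print(steps_per_epoch, t_total, epochs)  # 32 70 3.0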
@@ -527,6 +542,7 @@ class TFTrainer:
             logger.info("***** Running training *****")
             logger.info("  Num examples = %d", self.num_train_examples)
+            # TODO: We might want to print a more precise ``epochs`` if self.args.max_steps > 0 ?
             logger.info("  Num Epochs = %d", epochs)
             logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
             logger.info(
@@ -539,17 +555,23 @@
             self.train_loss = tf.keras.metrics.Sum()
             start_time = datetime.datetime.now()
 
-            for epoch_iter in range(epochs_trained, int(epochs + 1)):
+            for epoch_iter in range(epochs_trained, int(epochs)):
                 # Reset the past mems state at the beginning of each epoch if necessary.
                 if self.args.past_index >= 0:
                     self._past = None
 
                 for step, batch in enumerate(train_ds):
-                    self.global_step = iterations.numpy()
-                    self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch
 
                     # Skip past any already trained steps if resuming training
                     if steps_trained_in_current_epoch > 0:
                         steps_trained_in_current_epoch -= 1
                         continue
 
                     self.distributed_training_steps(batch)
 
+                    self.global_step = iterations.numpy()
+                    self.epoch_logging = epoch_iter + (step + 1) / self.steps_per_epoch
 
                     training_loss = self.train_loss.result() / (step + 1)
 
                     if self.args.debug:
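The loop changes above pair with the zero-based epochs_trained default introduced in the first hunk: epoch_iter now starts at 0, so the fractional epoch counter drops the "- 1", and global_step / epoch_logging are read only after distributed_training_steps has run. A minimal illustration, assuming an invented steps_per_epoch of 32:

# Illustrative only: how the fractional epoch counter changes with the new loop.
# steps_per_epoch = 32 is an invented figure, not taken from the repository.
steps_per_epoch = 32

def old_epoch_logging(epoch_iter, step):
    # Old loop: epoch_iter ranged over 1..epochs, hence the "- 1", and the value
    # was computed before the update step had actually run.
    return epoch_iter - 1 + (step + 1) / steps_per_epoch

def new_epoch_logging(epoch_iter, step):
    # New loop: epoch_iter ranges over 0..epochs-1 and the value is computed after
    # distributed_training_steps, once optimizer.iterations has advanced.
    return epoch_iter + (step + 1) / steps_per_epoch

assert old_epoch_logging(1, 0) == new_epoch_logging(0, 0) == 1 / 32
assert new_epoch_logging(0, 31) == 1.0  # end of the first epoch
assert new_epoch_logging(2, 15) == 2.5  # halfway through the third epoch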
@@ -566,13 +588,13 @@
                             )
 
                     if (
-                        self.global_step > 0
+                        self.args.eval_steps > 0
                         and self.args.evaluate_during_training
                         and self.global_step % self.args.eval_steps == 0
                     ):
                         self.evaluate()
 
-                    if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
+                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                         self.global_step == 1 and self.args.logging_first_step
                     ):
                         logs = {}
@@ -582,16 +604,22 @@
                         self.log(logs)
 
-                    if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
+                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                         ckpt_save_path = self.model.ckpt_manager.save()
 
                         logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
 
-                    if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
+                    if self.args.max_steps > 0 and self.global_step >= t_total:
                         break
 
+                    if self.global_step % self.steps_per_epoch == 0:
+                        break
+
                 self.train_loss.reset_states()
 
                 if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                     break
 
             end_time = datetime.datetime.now()
 
             logger.info("Training took: {}".format(str(end_time - start_time)))