Fix a few countings (steps / epochs) in trainer_tf.py (#7175)
This commit is contained in:
parent ee9eae4e06
commit 3a03bab9db
@@ -478,43 +478,58 @@ class TFTrainer:
         self.gradient_accumulator.reset()

+        num_update_steps_per_epoch = self.num_train_examples / self.total_train_batch_size
+
+        # In fact, ``self.args.dataloader_drop_last`` has no effect in `trainer_tf.py`, because
+        # the dataset is repeated before being batched.
+        # It has the effect only when TPU is used which requires explicit tensor shape in order to make
+        # the gradient accumulation implementation work.
+        approx = math.floor if self.args.dataloader_drop_last else math.ceil
+        num_update_steps_per_epoch = approx(num_update_steps_per_epoch)
+
+        # At least one update for each epoch.
+        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+        self.steps_per_epoch = num_update_steps_per_epoch
+
         if self.args.max_steps > 0:
             t_total = self.args.max_steps
-            self.steps_per_epoch = self.args.max_steps
+            epochs = (self.args.max_steps // self.steps_per_epoch) + int(
+                self.args.max_steps % self.steps_per_epoch > 0
+            )
         else:
-            approx = math.floor if self.args.dataloader_drop_last else math.ceil
-            self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
             t_total = self.steps_per_epoch * self.args.num_train_epochs
+            epochs = self.args.num_train_epochs
+
+        # Since ``self.args.num_train_epochs`` can be `float`, we make ``epochs`` be a `float` always.
+        epochs = float(epochs)

         with self.args.strategy.scope():
             self.create_optimizer_and_scheduler(num_training_steps=t_total)
-            iterations = self.optimizer.iterations
-            self.global_step = iterations.numpy()
             folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
             self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)

+            iterations = self.optimizer.iterations
+            epochs_trained = 0
+            steps_trained_in_current_epoch = 0
             if self.model.ckpt_manager.latest_checkpoint:
-                epochs_trained = self.global_step // (self.num_train_examples // self.args.gradient_accumulation_steps)
-                steps_trained_in_current_epoch = self.global_step % (
-                    self.num_train_examples // self.args.gradient_accumulation_steps
-                )
+                logger.info(
+                    "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
+                )
+                ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
+
+                self.global_step = iterations.numpy()
+
+                epochs_trained = self.global_step // self.steps_per_epoch
+                steps_trained_in_current_epoch = self.global_step % self.steps_per_epoch

                 logger.info(" Continuing training from checkpoint, will skip to saved global_step")
                 logger.info(" Continuing training from epoch %d", epochs_trained)
                 logger.info(" Continuing training from global step %d", self.global_step)
                 logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
-                logger.info(
-                    "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
-                )
-
-                ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
-            else:
-                epochs_trained = 1

-            tf.summary.experimental.set_step(iterations)
-
-            epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
+            tf.summary.experimental.set_step(self.global_step)

             if self.args.fp16:
                 policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
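Note: the following is a minimal standalone sketch of the counting logic this hunk introduces, using hypothetical numbers (1000 examples, effective batch size 32, max_steps 100, 3.0 epochs); it is illustrative only and not part of the diff.

import math

# Hypothetical configuration values, for illustration only.
num_train_examples = 1000
total_train_batch_size = 32
dataloader_drop_last = False
max_steps = 100
num_train_epochs = 3.0

approx = math.floor if dataloader_drop_last else math.ceil
# At least one update per epoch: ceil(1000 / 32) = 32 here.
steps_per_epoch = max(approx(num_train_examples / total_train_batch_size), 1)

if max_steps > 0:
    t_total = max_steps
    # Round up so a partial final epoch still counts: 100 // 32 + 1 = 4.
    epochs = (max_steps // steps_per_epoch) + int(max_steps % steps_per_epoch > 0)
else:
    t_total = steps_per_epoch * num_train_epochs
    epochs = num_train_epochs

# num_train_epochs may be a float, so epochs is kept a float in both branches.
epochs = float(epochs)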
@@ -527,6 +542,7 @@ class TFTrainer:

             logger.info("***** Running training *****")
             logger.info(" Num examples = %d", self.num_train_examples)
+            # TODO: We might want to print a more precise ``epochs`` if self.args.max_steps > 0 ?
             logger.info(" Num Epochs = %d", epochs)
             logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
             logger.info(
@@ -539,17 +555,23 @@ class TFTrainer:
             self.train_loss = tf.keras.metrics.Sum()
             start_time = datetime.datetime.now()

-            for epoch_iter in range(epochs_trained, int(epochs + 1)):
+            for epoch_iter in range(epochs_trained, int(epochs)):
                 # Reset the past mems state at the beginning of each epoch if necessary.
                 if self.args.past_index >= 0:
                     self._past = None

                 for step, batch in enumerate(train_ds):
-                    self.global_step = iterations.numpy()
-                    self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch
+                    # Skip past any already trained steps if resuming training
+                    if steps_trained_in_current_epoch > 0:
+                        steps_trained_in_current_epoch -= 1
+                        continue

                     self.distributed_training_steps(batch)

+                    self.global_step = iterations.numpy()
+                    self.epoch_logging = epoch_iter + (step + 1) / self.steps_per_epoch
+
                     training_loss = self.train_loss.result() / (step + 1)

                     if self.args.debug:
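Note: an illustrative sketch (not part of the diff) of how the corrected counters make mid-epoch resumption work; steps_per_epoch and the restored global_step are hypothetical values.

# Hypothetical values: 32 updates per epoch, checkpoint restored at update 70.
steps_per_epoch = 32
global_step = 70

epochs_trained = global_step // steps_per_epoch                 # 2 full epochs already done
steps_trained_in_current_epoch = global_step % steps_per_epoch  # 6 updates into the next epoch

for step in range(steps_per_epoch):
    # Mirrors the loop change above: already-trained steps are consumed and
    # skipped instead of being re-run, then training continues normally.
    if steps_trained_in_current_epoch > 0:
        steps_trained_in_current_epoch -= 1
        continue
    pass  # the real loop would run a training step on the batch here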
@@ -566,13 +588,13 @@ class TFTrainer:
                            )

                    if (
-                        self.global_step > 0
+                        self.args.eval_steps > 0
                         and self.args.evaluate_during_training
                         and self.global_step % self.args.eval_steps == 0
                     ):
                         self.evaluate()

-                    if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
+                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                         self.global_step == 1 and self.args.logging_first_step
                     ):
                         logs = {}
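Note: the guards changed in this hunk and the next one follow the same pattern, sketched here with a hypothetical helper that is not from the diff. Keying the check on the configured interval rather than on global_step lets an interval of 0 mean "disabled" and avoids a modulo-by-zero.

def should_run(global_step: int, every_n_steps: int) -> bool:
    # every_n_steps <= 0 disables the periodic action; the short-circuit also
    # prevents `global_step % 0` from raising ZeroDivisionError.
    return every_n_steps > 0 and global_step % every_n_steps == 0

assert should_run(100, 50) is True
assert should_run(100, 0) is False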
@@ -582,16 +604,22 @@ class TFTrainer:

                         self.log(logs)

-                    if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
+                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                         ckpt_save_path = self.model.ckpt_manager.save()

                         logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))

-                    if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
+                    if self.args.max_steps > 0 and self.global_step >= t_total:
+                        break
+
+                    if self.global_step % self.steps_per_epoch == 0:
                         break

                 self.train_loss.reset_states()

+                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
+                    break
+
             end_time = datetime.datetime.now()

             logger.info("Training took: {}".format(str(end_time - start_time)))
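Note: an illustrative sketch (not from the diff) of the new stopping behaviour when max_steps is set; the values are hypothetical and match the earlier counting example (32 updates per epoch, a budget of 100 updates).

steps_per_epoch = 32
max_steps = 100                              # t_total == max_steps in this branch
global_step = 0

for epoch in range(4):                       # epochs rounds up to 4 when max_steps = 100
    for step in range(steps_per_epoch):
        global_step += 1                     # one optimizer update
        if max_steps > 0 and global_step >= max_steps:
            break                            # stop mid-epoch once the step budget is spent
    if max_steps > 0 and global_step >= max_steps:
        break                                # also leave the epoch loop

assert global_step == 100                    # 3 full epochs (96 updates) + 4 extra updates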