fix ZeroDivisionError and epoch counting (#7125)
* fix ZeroDivisionError and epoch counting * Add test for num_train_epochs calculation in trainer.py * Remove @require_non_multigpu for test_num_train_epochs_in_training
This commit is contained in:
parent
7af2791d77
commit
4c62c6021a
|
@ -606,13 +606,15 @@ class Trainer:
|
|||
|
||||
# Data loader and number of training steps
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||
if self.args.max_steps > 0:
|
||||
t_total = self.args.max_steps
|
||||
num_train_epochs = (
|
||||
self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
|
||||
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
|
||||
self.args.max_steps % num_update_steps_per_epoch > 0
|
||||
)
|
||||
else:
|
||||
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
|
||||
t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
self.args.max_steps = t_total
|
||||
|
||||
|
@ -682,10 +684,8 @@ class Trainer:
|
|||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
||||
self.total_flos = getattr(model.config, "total_flos", 0)
|
||||
|
||||
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = self.global_step % (
|
||||
len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||
)
|
||||
epochs_trained = self.global_step // num_update_steps_per_epoch
|
||||
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
|
|
|
@ -302,3 +302,18 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
|
||||
loader = trainer.get_train_dataloader()
|
||||
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
||||
|
||||
def test_num_train_epochs_in_training(self):
|
||||
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
|
||||
# It should give 1 update step for each epoch.
|
||||
trainer = get_regression_trainer(
|
||||
max_steps=3, train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5
|
||||
)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 3)
|
||||
|
||||
# Even ``max_steps`` is not specified, we still expect 1 update step for each epoch if
|
||||
# len(train_dl) < gradient_accumulation_steps.
|
||||
trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
||||
|
|
Loading…
Reference in New Issue