Use generators tqdm progressbars (#6696)
This commit is contained in:
parent
a99d09c6f9
commit
f5bad031bc
|
@ -641,8 +641,8 @@ class Trainer:
|
|||
logging_loss = 0.0
|
||||
model.zero_grad()
|
||||
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
||||
train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
|
||||
for epoch in train_iterator:
|
||||
train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
|
||||
for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
|
||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||
train_dataloader.sampler.set_epoch(epoch)
|
||||
|
||||
|
@ -650,19 +650,21 @@ class Trainer:
|
|||
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
||||
self.args.device
|
||||
)
|
||||
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
|
||||
epoch_iterator = parallel_loader
|
||||
else:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)
|
||||
epoch_iterator = train_dataloader
|
||||
|
||||
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
|
||||
for step, inputs in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
epoch_pbar.update(1)
|
||||
continue
|
||||
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
|
@ -745,11 +747,12 @@ class Trainer:
|
|||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
|
||||
epoch_pbar.update(1)
|
||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
epoch_pbar.close()
|
||||
train_pbar.update(1)
|
||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
if is_torch_tpu_available():
|
||||
|
@ -761,6 +764,7 @@ class Trainer:
|
|||
"configured. Check your training configuration if this is unexpected."
|
||||
)
|
||||
|
||||
train_pbar.close()
|
||||
if self.tb_writer:
|
||||
self.tb_writer.close()
|
||||
if self.args.past_index and hasattr(self, "_past"):
|
||||
|
|
Loading…
Reference in New Issue