Use separate tqdm progressbars (#6696)

Sylvain Gugger 2020-08-25 07:06:58 -04:00 committed by GitHub
parent a99d09c6f9
commit f5bad031bc
1 changed file with 10 additions and 6 deletions


@@ -641,8 +641,8 @@ class Trainer:
         logging_loss = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
-        train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
-        for epoch in train_iterator:
+        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
+        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
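
Not part of the diff: a minimal sketch of the pattern this hunk switches to, using hypothetical values. Instead of iterating the trange object itself, the epoch loop runs over a plain range while a standalone progress bar is advanced by hand:

    # Sketch only (hypothetical names/values): standalone epoch bar, advanced manually.
    from tqdm.auto import trange

    epochs_trained, num_epochs = 0, 3  # hypothetical values

    train_pbar = trange(epochs_trained, num_epochs, desc="Epoch")
    for epoch in range(epochs_trained, num_epochs):
        ...  # run one epoch of training
        train_pbar.update(1)
    train_pbar.close()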
@@ -650,19 +650,21 @@ class Trainer:
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )
-                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = parallel_loader
             else:
-                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = train_dataloader
             # Reset the past mems state at the beginning of each epoch if necessary.
             if self.args.past_index >= 0:
                 self._past = None
+            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
             for step, inputs in enumerate(epoch_iterator):
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
+                    epoch_pbar.update(1)
                     continue
                 tr_loss += self.training_step(model, inputs)
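
A sketch of the step-level counterpart (names and numbers are placeholders, not repo code): the dataloader or ParallelLoader generator is no longer wrapped in tqdm; it is enumerated directly while a separate bar is updated on every step, so steps skipped when resuming from a checkpoint still advance the bar:

    # Sketch only: separate per-step bar that also counts skipped steps.
    from tqdm.auto import tqdm

    train_dataloader = range(100)  # stand-in for a DataLoader or generator
    steps_to_skip = 10             # hypothetical resume offset

    epoch_pbar = tqdm(train_dataloader, desc="Iteration")
    for step, batch in enumerate(train_dataloader):
        if steps_to_skip > 0:
            steps_to_skip -= 1
            epoch_pbar.update(1)
            continue
        ...  # training_step(batch)
        epoch_pbar.update(1)
    epoch_pbar.close()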
@@ -745,11 +747,12 @@ class Trainer:
                             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                             torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                epoch_pbar.update(1)
                 if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                    epoch_iterator.close()
                     break
+            epoch_pbar.close()
+            train_pbar.update(1)
             if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                train_iterator.close()
                 break
             if self.args.tpu_metrics_debug or self.args.debug:
                 if is_torch_tpu_available():
@@ -761,6 +764,7 @@ class Trainer:
                         "configured. Check your training configuration if this is unexpected."
                     )
+        train_pbar.close()
         if self.tb_writer:
             self.tb_writer.close()
         if self.args.past_index and hasattr(self, "_past"):
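
Taken together, the bars now have an explicit lifecycle: the inner bar is closed at the end of each epoch and the outer one once training finishes, so an early break when max_steps is reached no longer needs to call close() on a wrapped iterator. A compressed sketch of the resulting structure (hypothetical values, not repo code):

    # Sketch only: nested bars with explicit close(); an early break needs no iterator.close().
    from tqdm.auto import tqdm, trange

    data, num_epochs, max_steps, global_step = range(100), 3, 25, 0  # hypothetical

    train_pbar = trange(num_epochs, desc="Epoch")
    for epoch in range(num_epochs):
        epoch_pbar = tqdm(data, desc="Iteration")
        for batch in data:
            global_step += 1
            epoch_pbar.update(1)
            if max_steps > 0 and global_step >= max_steps:
                break
        epoch_pbar.close()
        train_pbar.update(1)
        if max_steps > 0 and global_step >= max_steps:
            break
    train_pbar.close()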