From f5bad031bcab6d431b74d489e8cd238965f94ddb Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Tue, 25 Aug 2020 07:06:58 -0400
Subject: [PATCH] Use generators tqdm progressbars (#6696)

---
 src/transformers/trainer.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c73b7f0f34..5a8d0a9709 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -641,8 +641,8 @@ class Trainer:
         logging_loss = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
-        train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
-        for epoch in train_iterator:
+        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
+        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)

@@ -650,19 +650,21 @@ class Trainer:
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )
-                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = parallel_loader
             else:
-                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)
+                epoch_iterator = train_dataloader

             # Reset the past mems state at the beginning of each epoch if necessary.
             if self.args.past_index >= 0:
                 self._past = None

+            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
             for step, inputs in enumerate(epoch_iterator):

                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
+                    epoch_pbar.update(1)
                     continue

                 tr_loss += self.training_step(model, inputs)
@@ -745,11 +747,12 @@ class Trainer:
                             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                             torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

+                epoch_pbar.update(1)
                 if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                    epoch_iterator.close()
                     break
+            epoch_pbar.close()
+            train_pbar.update(1)
             if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
-                train_iterator.close()
                 break
             if self.args.tpu_metrics_debug or self.args.debug:
                 if is_torch_tpu_available():
@@ -761,6 +764,7 @@ class Trainer:
                     "configured. Check your training configuration if this is unexpected."
                 )

+        train_pbar.close()
         if self.tb_writer:
             self.tb_writer.close()
         if self.args.past_index and hasattr(self, "_past"):
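
Below is a minimal, self-contained sketch (not part of the patch) of the pattern this change adopts: the tqdm progress bars are created separately from the loop iterables and advanced with explicit update(1) calls, so skipped resume steps can still move the inner bar and each bar can be closed independently of early breaks. The train() function, the toy dataloader, and the sleep placeholder for the training step are illustrative assumptions, not code from the repository.

import time

from tqdm.auto import tqdm, trange


def train(dataloader, num_epochs, disable_tqdm=False):
    # Outer bar tracks epochs; it is advanced manually instead of wrapping the loop.
    train_pbar = trange(num_epochs, desc="Epoch", disable=disable_tqdm)
    for _ in range(num_epochs):
        # Inner bar tracks steps; tqdm only reads len(dataloader) for the total,
        # while the actual iteration happens over the dataloader itself.
        epoch_pbar = tqdm(dataloader, desc="Iteration", disable=disable_tqdm)
        for batch in dataloader:
            time.sleep(0.01)  # stand-in for the real training step
            epoch_pbar.update(1)
        epoch_pbar.close()
        train_pbar.update(1)
    train_pbar.close()


if __name__ == "__main__":
    train(list(range(20)), num_epochs=2)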