Reload checkpoint (#7984)

* Fix checkpoint loading in Trainer

* Fix typo
Sylvain Gugger 2020-10-22 15:48:52 -04:00 committed by GitHub
parent 467573ddde
commit 5ae935d233
3 changed files with 35 additions and 17 deletions
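
The change is easiest to read alongside a resume-from-checkpoint run. Below is a minimal, hypothetical sketch of such a run: the model name, the ToyDataset helper, and the "out/checkpoint-500" paths are made up for illustration, while the `model_path` argument to `train()` and the optimizer.pt / scheduler.pt files it picks up come from the diff that follows.

# Hypothetical resume-from-checkpoint sketch (not part of this commit).
# Model name, ToyDataset, and "out/checkpoint-500" are illustrative;
# trainer.train(model_path=...) and the checkpoint layout match the diff below.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

class ToyDataset(torch.utils.data.Dataset):
    """Tiny stand-in dataset so the sketch is self-contained."""
    def __init__(self, tokenizer):
        self.enc = tokenizer(["a", "b", "c", "d"], padding=True, return_tensors="pt")
        self.labels = torch.tensor([0, 1, 0, 1])
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, i):
        item = {k: v[i] for k, v in self.enc.items()}
        item["labels"] = self.labels[i]
        return item

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("out/checkpoint-500")
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", save_steps=500),
    train_dataset=ToyDataset(tokenizer),
)
# Pointing model_path at the checkpoint folder makes the Trainer reload
# optimizer.pt and scheduler.pt via _load_optimizer_and_scheduler (added below).
trainer.train(model_path="out/checkpoint-500")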

src/transformers/trainer.py

@@ -628,18 +628,7 @@ class Trainer:
         self.state.is_hyper_param_search = trial is not None

-        # Check if saved optimizer or scheduler states exist
-        if (
-            model_path is not None
-            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
-            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
-        ):
-            # Load in optimizer and scheduler states
-            self.optimizer.load_state_dict(
-                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
-            )
-            with warnings.catch_warnings(record=True) as caught_warnings:
-                self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
-            reissue_pt_warnings(caught_warnings)
+        self._load_optimizer_and_scheduler(model_path)

         # Mixed precision training with apex (torch < 1.6)
         model = self.model
@@ -919,6 +908,34 @@ class Trainer:
         if self.is_world_process_zero():
             self._rotate_checkpoints(use_mtime=True)

+    def _load_optimizer_and_scheduler(self, model_path):
+        """If optimizer and scheduler states exist, load them."""
+        if (
+            model_path is not None
+            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
+            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
+        ):
+            # Load in optimizer and scheduler states
+            if is_torch_tpu_available():
+                # On TPU we have to take some extra precautions to properly load the states on the right device.
+                optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
+                reissue_pt_warnings(caught_warnings)
+
+                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+                self.optimizer.load_state_dict(optimizer_state)
+                self.lr_scheduler.load_state_dict(lr_scheduler_state)
+            else:
+                self.optimizer.load_state_dict(
+                    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
+                )
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+                reissue_pt_warnings(caught_warnings)
+
     def hyperparameter_search(
         self,
         hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
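
For context, this is roughly the round trip the new helper relies on, shown as a standalone sketch outside the Trainer: optimizer and scheduler states are written as optimizer.pt and scheduler.pt next to the model, then reloaded onto the training device with map_location. The file names match the diff; the toy Linear model, AdamW, and StepLR are illustrative assumptions.

# Standalone sketch (toy model, optimizer, and scheduler are illustrative).
import os
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

checkpoint_dir = "out/checkpoint-500"
os.makedirs(checkpoint_dir, exist_ok=True)
# Save side: these two files end up next to the model weights in the checkpoint folder.
torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))

# Load side: map tensors straight onto the training device, as the helper
# above does with self.args.device (or onto CPU first when running on TPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device)
)
scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))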

src/transformers/trainer_callback.py

@@ -436,10 +436,12 @@ class ProgressCallback(TrainerCallback):
     def on_train_begin(self, args, state, control, **kwargs):
         if state.is_local_process_zero:
             self.training_bar = tqdm(total=state.max_steps)
+        self.current_step = 0

     def on_step_end(self, args, state, control, **kwargs):
         if state.is_local_process_zero:
-            self.training_bar.update(1)
+            self.training_bar.update(state.global_step - self.current_step)
+            self.current_step = state.global_step

     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
         if state.is_local_process_zero:
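
The callback change matters when a run is resumed: global_step starts at the checkpoint's step, so updating by the delta lets the bar jump straight to the right position instead of counting up from zero. A small standalone sketch, with made-up step counts:

# Illustrative only: simulate on_step_end after resuming at step 500 of 1000.
from tqdm import tqdm

max_steps, resumed_at = 1000, 500
training_bar = tqdm(total=max_steps)
current_step = 0  # set in on_train_begin

for global_step in range(resumed_at + 1, max_steps + 1):
    # First call advances by 501 (the 500 restored steps plus the one just taken);
    # every later call adds 1.
    training_bar.update(global_step - current_step)
    current_step = global_step
training_bar.close()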

src/transformers/trainer_pt_utils.py

@@ -23,6 +23,7 @@ from typing import List, Optional, Union
 import numpy as np
 import torch
+from torch.optim.lr_scheduler import SAVE_STATE_WARNING
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
@@ -33,8 +34,6 @@ from .utils import logging
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm

-PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."

 logger = logging.get_logger(__name__)
@@ -112,10 +111,10 @@ def distributed_broadcast_scalars(
 def reissue_pt_warnings(caught_warnings):
-    # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
+    # Reissue warnings that are not the SAVE_STATE_WARNING
     if len(caught_warnings) > 1:
         for w in caught_warnings:
-            if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
+            if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
                 warnings.warn(w.message, w.category)
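
Finally, a standalone sketch of the catch-and-reissue pattern this helper implements. The noisy_load function and the locally defined SAVE_STATE_WARNING string are illustrative stand-ins (so the sketch runs without importing anything version-specific from torch); the sketch compares str(w.message) because recorded warnings wrap the message in a Warning instance.

# Illustrative only: record warnings, then re-emit all but the known save-state message.
import warnings

SAVE_STATE_WARNING = "Please also save or load the state of the optimizer when saving or loading the scheduler."

def noisy_load():
    # Stand-in for lr_scheduler.load_state_dict(...), which may emit this warning.
    warnings.warn(SAVE_STATE_WARNING, UserWarning)
    warnings.warn("something else worth surfacing", UserWarning)

with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")
    noisy_load()

for w in caught_warnings:
    if w.category is not UserWarning or str(w.message) != SAVE_STATE_WARNING:
        warnings.warn(w.message, w.category)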