Reload checkpoint (#7984)
* Fix checkpoint loading in Trainer * Fix typo
This commit is contained in:
parent
467573ddde
commit
5ae935d233
|
@ -628,18 +628,7 @@ class Trainer:
|
|||
self.state.is_hyper_param_search = trial is not None
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if (
|
||||
model_path is not None
|
||||
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
|
||||
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
self._load_optimizer_and_scheduler(model_path)
|
||||
|
||||
# Mixed precision training with apex (torch < 1.6)
|
||||
model = self.model
|
||||
|
@ -919,6 +908,34 @@ class Trainer:
|
|||
if self.is_world_process_zero():
|
||||
self._rotate_checkpoints(use_mtime=True)
|
||||
|
||||
def _load_optimizer_and_scheduler(self, model_path):
|
||||
"""If optimizer and scheduler states exist, load them."""
|
||||
if (
|
||||
model_path is not None
|
||||
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
|
||||
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
if is_torch_tpu_available():
|
||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
|
||||
|
||||
self.optimizer.load_state_dict(optimizer_state)
|
||||
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
||||
else:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def hyperparameter_search(
|
||||
self,
|
||||
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
||||
|
|
|
@ -436,10 +436,12 @@ class ProgressCallback(TrainerCallback):
|
|||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.training_bar = tqdm(total=state.max_steps)
|
||||
self.current_step = 0
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.training_bar.update(1)
|
||||
self.training_bar.update(state.global_step - self.current_step)
|
||||
self.current_step = state.global_step
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
|
|
|
@ -23,6 +23,7 @@ from typing import List, Optional, Union
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
|
@ -33,8 +34,6 @@ from .utils import logging
|
|||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -112,10 +111,10 @@ def distributed_broadcast_scalars(
|
|||
|
||||
|
||||
def reissue_pt_warnings(caught_warnings):
    """Re-emit every caught warning except PyTorch's ``SAVE_STATE_WARNING``.

    Intended to be called right after a ``warnings.catch_warnings(record=True)``
    block so that only the scheduler save/load notice is suppressed.

    Args:
        caught_warnings: Iterable of recorded warnings; each item exposes
            ``category`` and ``message`` attributes.
    """
    for caught in caught_warnings:
        # A recorded ``message`` is a Warning instance, not a string, so its text
        # has to be compared; comparing the instance itself never matches.
        is_save_state_notice = caught.category is UserWarning and str(caught.message) == SAVE_STATE_WARNING
        if not is_save_state_notice:
            # Every other warning is passed on, including when it is the only one caught.
            warnings.warn(caught.message, caught.category)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue