Test checkpointing (#11682)
* Add test and see where CI is unhappy
* Load with strict=False
This commit is contained in:
parent
d9b286272c
commit
f13f1f8fb8
|
@ -1059,7 +1059,18 @@ class Trainer:
|
|||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||
# If the model is on the GPU, it still works!
|
||||
self.model.load_state_dict(state_dict)
|
||||
load_result = self.model.load_state_dict(state_dict, strict=False)
|
||||
if len(load_result.missing_keys) != 0:
|
||||
if load_result.missing_keys == self.model._keys_to_ignore_on_save:
|
||||
self.model.tie_weights()
|
||||
else:
|
||||
logger.warn(
|
||||
f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}."
|
||||
)
|
||||
if len(load_result.unexpected_keys) != 0:
|
||||
logger.warn(
|
||||
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
|
||||
)
|
||||
|
||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||
if model_reloaded:
|
||||
|
|
|
@ -177,6 +177,13 @@ class ModelTesterMixin:
|
|||
for k in _keys_to_ignore_on_save:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
|
||||
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
||||
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
||||
self.assertTrue(
|
||||
len(load_result.missing_keys) == 0 or load_result.missing_keys == model._keys_to_ignore_on_save
|
||||
)
|
||||
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
||||
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
|
|
Loading…
Reference in New Issue