Introduce Stateful Callbacks (#29666)
* Introduce saveable callbacks * Add note * Test for non-present and flag * Support early stopping and refusing to train further * Update docstring * More saving * Import oopsie * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Make it go through TrainerArguments * Document * Fix test * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Rework to allow for duplicates * CLean * Fix failing tests --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
86f2569738
commit
ad697f1801
|
@ -78,6 +78,7 @@ from .tokenization_utils_base import PreTrainedTokenizerBase
|
|||
from .trainer_callback import (
|
||||
CallbackHandler,
|
||||
DefaultFlowCallback,
|
||||
ExportableState,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
TrainerCallback,
|
||||
|
@ -649,12 +650,15 @@ class Trainer:
|
|||
else:
|
||||
self.label_smoother = None
|
||||
|
||||
self.control = TrainerControl()
|
||||
|
||||
self.state = TrainerState(
|
||||
is_local_process_zero=self.is_local_process_zero(),
|
||||
is_world_process_zero=self.is_world_process_zero(),
|
||||
stateful_callbacks=[
|
||||
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
||||
],
|
||||
)
|
||||
|
||||
self.control = TrainerControl()
|
||||
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
|
||||
# returned to 0 every time flos need to be logged
|
||||
self.current_flos = 0
|
||||
|
@ -1499,6 +1503,8 @@ class Trainer:
|
|||
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||
self.save_model(output_dir, _internal_call=True)
|
||||
if self.args.should_save:
|
||||
# Update the `TrainerControl` state to where we are currently
|
||||
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
|
||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
|
@ -1970,7 +1976,11 @@ class Trainer:
|
|||
if not delay_optimizer_creation:
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
|
||||
self.state = TrainerState()
|
||||
self.state = TrainerState(
|
||||
stateful_callbacks=[
|
||||
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
||||
]
|
||||
)
|
||||
self.state.is_hyper_param_search = trial is not None
|
||||
self.state.train_batch_size = self._train_batch_size
|
||||
|
||||
|
@ -2079,6 +2089,7 @@ class Trainer:
|
|||
):
|
||||
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
||||
self.compare_trainer_and_checkpoint_args(self.args, self.state)
|
||||
self._load_callback_state()
|
||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||
if not args.ignore_data_skip:
|
||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||
|
@ -2786,6 +2797,8 @@ class Trainer:
|
|||
|
||||
# Save the Trainer state
|
||||
if self.args.should_save:
|
||||
# Update the `TrainerControl` state to where we are currently
|
||||
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
|
||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||
|
||||
if self.args.push_to_hub:
|
||||
|
@ -2970,6 +2983,45 @@ class Trainer:
|
|||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def _load_callback_state(self):
|
||||
"""If callback states exist and were passed in, restore their states if enabled"""
|
||||
if not self.args.restore_callback_states_from_checkpoint:
|
||||
return
|
||||
# Callback states are stored in stateful_callbacks
|
||||
not_found = []
|
||||
new_callbacks = []
|
||||
original_callbacks = self.callback_handler.callbacks + [self.control]
|
||||
for stored_callback, data in self.state.stateful_callbacks.items():
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
|
||||
# We can load/restore from multiple callbacks of the same type.
|
||||
duplicates = [
|
||||
callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
|
||||
]
|
||||
for callback, callback_data in zip(duplicates, data):
|
||||
args = callback_data.get("args", {})
|
||||
attributes = callback_data.get("attributes", {})
|
||||
new_callback = type(callback)(**args)
|
||||
for attribute, value in attributes.items():
|
||||
setattr(new_callback, attribute, value)
|
||||
if isinstance(callback, TrainerControl):
|
||||
# Specifically for restoring the `control` state
|
||||
self.control = new_callback
|
||||
else:
|
||||
new_callbacks.append(new_callback)
|
||||
# We remove the existing callback and add it to the list of new callbacks
|
||||
self.callback_handler.remove_callback(type(new_callback))
|
||||
logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
|
||||
else:
|
||||
not_found.append(stored_callback)
|
||||
if len(not_found) > 0:
|
||||
logger.warning(
|
||||
f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
|
||||
)
|
||||
for callback in new_callbacks:
|
||||
self.callback_handler.add_callback(callback)
|
||||
|
||||
def hyperparameter_search(
|
||||
self,
|
||||
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
||||
|
|
|
@ -84,6 +84,9 @@ class TrainerState:
|
|||
is_hyper_param_search (`bool`, *optional*, defaults to `False`):
|
||||
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
|
||||
impact the way data will be logged in TensorBoard.
|
||||
stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
|
||||
Callbacks attached to the `Trainer` that should have their states be saved or restored.
|
||||
Relevent callbacks should implement a `state` and `from_state` function.
|
||||
"""
|
||||
|
||||
epoch: Optional[float] = None
|
||||
|
@ -104,10 +107,34 @@ class TrainerState:
|
|||
is_hyper_param_search: bool = False
|
||||
trial_name: str = None
|
||||
trial_params: Dict[str, Union[str, float, int, bool]] = None
|
||||
stateful_callbacks: List["TrainerCallback"] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.log_history is None:
|
||||
self.log_history = []
|
||||
if self.stateful_callbacks is None:
|
||||
self.stateful_callbacks = {}
|
||||
elif isinstance(self.stateful_callbacks, dict):
|
||||
# We are loading the callbacks in from the state file, no need to process them
|
||||
pass
|
||||
else:
|
||||
# Saveable callbacks get stored as dict of kwargs
|
||||
stateful_callbacks = {}
|
||||
for callback in self.stateful_callbacks:
|
||||
if not isinstance(callback, (ExportableState)):
|
||||
raise TypeError(
|
||||
f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
|
||||
)
|
||||
name = callback.__class__.__name__
|
||||
if name in stateful_callbacks:
|
||||
# We can have multiple versions of the same callback
|
||||
# if so, we store them as a list of states to restore
|
||||
if not isinstance(stateful_callbacks[name], list):
|
||||
stateful_callbacks[name] = [stateful_callbacks[name]]
|
||||
stateful_callbacks[name].append(callback.state())
|
||||
else:
|
||||
stateful_callbacks[name] = callback.state()
|
||||
self.stateful_callbacks = stateful_callbacks
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
"""Save the content of this instance in JSON format inside `json_path`."""
|
||||
|
@ -123,8 +150,52 @@ class TrainerState:
|
|||
return cls(**json.loads(text))
|
||||
|
||||
|
||||
class ExportableState:
|
||||
"""
|
||||
A class for objects that include the ability to have its state
|
||||
be saved during `Trainer._save_checkpoint` and loaded back in during
|
||||
`Trainer._load_from_checkpoint`.
|
||||
|
||||
These must implement a `state` function that gets called during the respective
|
||||
Trainer function call. It should only include parameters and attributes needed to
|
||||
recreate the state at a particular time, to avoid utilizing pickle/maintain standard
|
||||
file IO writing.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
class EarlyStoppingCallback(TrainerCallback, ExportableState):
|
||||
def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
|
||||
self.early_stopping_patience = early_stopping_patience
|
||||
self.early_stopping_threshold = early_stopping_threshold
|
||||
# early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
|
||||
self.early_stopping_patience_counter = 0
|
||||
|
||||
def state(self) -> dict:
|
||||
return {
|
||||
"args": {
|
||||
"early_stopping_patience": self.early_stopping_patience,
|
||||
"early_stopping_threshold": self.early_stopping_threshold,
|
||||
},
|
||||
"attributes": {
|
||||
"early_stopping_patience_counter": self.early_stopping_patience_counter,
|
||||
}
|
||||
}
|
||||
```"""
|
||||
|
||||
def state(self) -> dict:
|
||||
raise NotImplementedError("You must implement a `state` function to utilize this class.")
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
instance = cls(**state["args"])
|
||||
for k, v in state["attributes"].items():
|
||||
setattr(instance, k, v)
|
||||
return instance
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerControl:
|
||||
class TrainerControl(ExportableState):
|
||||
"""
|
||||
A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
|
||||
switches in the training loop.
|
||||
|
@ -172,6 +243,18 @@ class TrainerControl:
|
|||
self.should_evaluate = False
|
||||
self.should_log = False
|
||||
|
||||
def state(self) -> dict:
|
||||
return {
|
||||
"args": {
|
||||
"should_training_stop": self.should_training_stop,
|
||||
"should_epoch_stop": self.should_epoch_stop,
|
||||
"should_save": self.should_save,
|
||||
"should_evaluate": self.should_evaluate,
|
||||
"should_log": self.should_log,
|
||||
},
|
||||
"attributes": {},
|
||||
}
|
||||
|
||||
|
||||
class TrainerCallback:
|
||||
# no-format
|
||||
|
@ -546,7 +629,7 @@ class PrinterCallback(TrainerCallback):
|
|||
print(logs)
|
||||
|
||||
|
||||
class EarlyStoppingCallback(TrainerCallback):
|
||||
class EarlyStoppingCallback(TrainerCallback, ExportableState):
|
||||
"""
|
||||
A [`TrainerCallback`] that handles early stopping.
|
||||
|
||||
|
@ -605,3 +688,14 @@ class EarlyStoppingCallback(TrainerCallback):
|
|||
self.check_metric_value(args, state, control, metric_value)
|
||||
if self.early_stopping_patience_counter >= self.early_stopping_patience:
|
||||
control.should_training_stop = True
|
||||
|
||||
def state(self) -> dict:
|
||||
return {
|
||||
"args": {
|
||||
"early_stopping_patience": self.early_stopping_patience,
|
||||
"early_stopping_threshold": self.early_stopping_threshold,
|
||||
},
|
||||
"attributes": {
|
||||
"early_stopping_patience_counter": self.early_stopping_patience_counter,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -357,6 +357,9 @@ class TrainingArguments:
|
|||
Note that when this is true, you won't be able to resume training from checkpoint.
|
||||
This enables you to save storage by not storing the optimizer, scheduler & rng state.
|
||||
You can only load the model using `from_pretrained` with this option set to `True`.
|
||||
restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
|
||||
Whether to restore the callback states from the checkpoint. If `True`, will override
|
||||
callbacks passed to the `Trainer` if they exist in the checkpoint."
|
||||
use_cpu (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
|
||||
seed (`int`, *optional*, defaults to 42):
|
||||
|
@ -951,6 +954,12 @@ class TrainingArguments:
|
|||
)
|
||||
},
|
||||
)
|
||||
restore_callback_states_from_checkpoint: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
|
||||
},
|
||||
)
|
||||
no_cuda: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
@ -19,28 +21,44 @@ from unittest.mock import patch
|
|||
|
||||
from transformers import (
|
||||
DefaultFlowCallback,
|
||||
EarlyStoppingCallback,
|
||||
IntervalStrategy,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.trainer_callback import ExportableState
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.trainer import DEFAULT_CALLBACKS
|
||||
from transformers.trainer import DEFAULT_CALLBACKS, TRAINER_STATE_NAME
|
||||
|
||||
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
||||
|
||||
|
||||
class MyTestExportableCallback(TrainerCallback, ExportableState):
|
||||
def __init__(self, my_test_state="test"):
|
||||
self.my_test_state = my_test_state
|
||||
|
||||
def state(self):
|
||||
return {
|
||||
"args": {
|
||||
"my_test_state": self.my_test_state,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MyTestTrainerCallback(TrainerCallback):
|
||||
"A callback that registers the events that goes through."
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, my_test_state="test"):
|
||||
self.events = []
|
||||
self.my_test_state = my_test_state
|
||||
|
||||
def on_init_end(self, args, state, control, **kwargs):
|
||||
self.events.append("on_init_end")
|
||||
|
@ -243,3 +261,160 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
||||
)
|
||||
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
||||
|
||||
def test_stateful_callbacks(self):
|
||||
# Use something with non-defaults
|
||||
cb = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2)
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[cb],
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Create a new trainer with defaults
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[EarlyStoppingCallback()],
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
restore_callback_states_from_checkpoint=True,
|
||||
)
|
||||
# Load it back in and verify values
|
||||
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
cb = [
|
||||
callback for callback in trainer.callback_handler.callbacks if isinstance(callback, EarlyStoppingCallback)
|
||||
][0]
|
||||
assert cb.early_stopping_patience == 5
|
||||
assert cb.early_stopping_threshold == 0.2
|
||||
|
||||
def test_stateful_mixed_callbacks(self):
|
||||
# Use two callbacks, one stateful one not
|
||||
# Use something with non-defaults
|
||||
cbs = [
|
||||
MyTestTrainerCallback(my_test_state="another value"),
|
||||
EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2),
|
||||
]
|
||||
trainer = self.get_trainer(
|
||||
callbacks=cbs,
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Create a new trainer with defaults
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()],
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
restore_callback_states_from_checkpoint=True,
|
||||
)
|
||||
# Load it back in and verify values
|
||||
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
cbs = [
|
||||
callback
|
||||
for callback in trainer.callback_handler.callbacks
|
||||
if isinstance(callback, (EarlyStoppingCallback, MyTestTrainerCallback))
|
||||
]
|
||||
assert len(cbs) == 2
|
||||
my_test, early_stopping = cbs
|
||||
assert early_stopping.early_stopping_patience == 5
|
||||
assert early_stopping.early_stopping_threshold == 0.2
|
||||
assert my_test.my_test_state == "test"
|
||||
|
||||
def test_stateful_duplicate_callbacks(self):
|
||||
# Use something with non-defaults
|
||||
cbs = [MyTestExportableCallback("first"), MyTestExportableCallback("second")]
|
||||
trainer = self.get_trainer(
|
||||
callbacks=cbs,
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Create a new trainer with defaults
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
restore_callback_states_from_checkpoint=True,
|
||||
)
|
||||
# Load it back in and verify values
|
||||
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
cbs = [
|
||||
callback
|
||||
for callback in trainer.callback_handler.callbacks
|
||||
if isinstance(callback, MyTestExportableCallback)
|
||||
]
|
||||
assert len(cbs) == 2
|
||||
assert cbs[0].my_test_state == "first"
|
||||
assert cbs[1].my_test_state == "second"
|
||||
|
||||
def test_missing_stateful_callback(self):
|
||||
cb = EarlyStoppingCallback()
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[cb],
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Create a new trainer with defaults
|
||||
trainer = self.get_trainer(
|
||||
save_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
save_steps=2,
|
||||
eval_steps=2,
|
||||
max_steps=2,
|
||||
restore_callback_states_from_checkpoint=True,
|
||||
)
|
||||
# Load it back in and verify values
|
||||
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||
# warning should be emitted for not-present callbacks
|
||||
with patch("transformers.trainer.logger.warning") as warn_mock:
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
assert "EarlyStoppingCallback" in warn_mock.call_args[0][0]
|
||||
|
||||
def test_stateful_control(self):
|
||||
trainer = self.get_trainer(
|
||||
max_steps=2,
|
||||
save_strategy="steps",
|
||||
save_steps=2,
|
||||
)
|
||||
trainer.train()
|
||||
# Load it back in and verify values
|
||||
trainer = self.get_trainer(max_steps=2, restore_callback_states_from_checkpoint=True)
|
||||
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
|
||||
trainer._load_callback_state()
|
||||
assert trainer.control.should_training_stop
|
||||
|
|
Loading…
Reference in New Issue