Introduce Stateful Callbacks (#29666)

* Introduce saveable callbacks

* Add note

* Test for non-present and flag

* Support early stopping and refusing to train further

* Update docstring

* More saving

* Import oopsie

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Make it go through TrainerArguments

* Document

* Fix test

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Rework to allow for duplicates

* Clean

* Fix failing tests

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Zach Mueller 2024-04-25 11:00:09 -04:00 committed by GitHub
parent 86f2569738
commit ad697f1801
4 changed files with 337 additions and 7 deletions

src/transformers/trainer.py

@@ -78,6 +78,7 @@ from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
ExportableState,
PrinterCallback,
ProgressCallback,
TrainerCallback,
@@ -649,12 +650,15 @@ class Trainer:
else:
self.label_smoother = None
self.control = TrainerControl()
self.state = TrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)
self.control = TrainerControl()
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
# returned to 0 every time flos need to be logged
self.current_flos = 0
@@ -1499,6 +1503,8 @@ class Trainer:
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir, _internal_call=True)
if self.args.should_save:
# Update the `TrainerControl` state to where we are currently
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -1970,7 +1976,11 @@ class Trainer:
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
self.state = TrainerState(
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]
)
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size
@@ -2079,6 +2089,7 @@ class Trainer:
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
self.compare_trainer_and_checkpoint_args(self.args, self.state)
self._load_callback_state()
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
@@ -2786,6 +2797,8 @@ class Trainer:
# Save the Trainer state
if self.args.should_save:
# Update the `TrainerControl` state to where we are currently
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
if self.args.push_to_hub:
@@ -2970,6 +2983,45 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
def _load_callback_state(self):
"""If callback states exist and were passed in, restore their states if enabled"""
if not self.args.restore_callback_states_from_checkpoint:
return
# Callback states are stored in stateful_callbacks
not_found = []
new_callbacks = []
original_callbacks = self.callback_handler.callbacks + [self.control]
for stored_callback, data in self.state.stateful_callbacks.items():
if not isinstance(data, list):
data = [data]
if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
# We can load/restore from multiple callbacks of the same type.
duplicates = [
callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
]
for callback, callback_data in zip(duplicates, data):
args = callback_data.get("args", {})
attributes = callback_data.get("attributes", {})
new_callback = type(callback)(**args)
for attribute, value in attributes.items():
setattr(new_callback, attribute, value)
if isinstance(callback, TrainerControl):
# Specifically for restoring the `control` state
self.control = new_callback
else:
new_callbacks.append(new_callback)
# We remove the existing callback and add it to the list of new callbacks
self.callback_handler.remove_callback(type(new_callback))
logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
else:
not_found.append(stored_callback)
if len(not_found) > 0:
logger.warning(
f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
)
for callback in new_callbacks:
self.callback_handler.add_callback(callback)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,

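For orientation, the `stateful_callbacks` entry that `_save_checkpoint` writes into `trainer_state.json` and that `_load_callback_state` reads back has roughly the following shape; this is an illustrative sketch (the values are made up), mirroring the `state()` payloads defined in trainer_callback.py below.

```python
# Illustrative only: keys are callback class names, values are their `state()` payloads.
trainer_state_snippet = {
    "stateful_callbacks": {
        "EarlyStoppingCallback": {
            "args": {"early_stopping_patience": 5, "early_stopping_threshold": 0.2},
            "attributes": {"early_stopping_patience_counter": 1},
        },
        "TrainerControl": {
            "args": {
                "should_training_stop": False,
                "should_epoch_stop": False,
                "should_save": True,
                "should_evaluate": False,
                "should_log": False,
            },
            "attributes": {},
        },
    }
}
```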
src/transformers/trainer_callback.py

@@ -84,6 +84,9 @@ class TrainerState:
is_hyper_param_search (`bool`, *optional*, defaults to `False`):
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
impact the way data will be logged in TensorBoard.
stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
Callbacks attached to the `Trainer` that should have their states saved or restored.
Relevant callbacks should implement a `state` and `from_state` function.
"""
epoch: Optional[float] = None
@@ -104,10 +107,34 @@ class TrainerState:
is_hyper_param_search: bool = False
trial_name: str = None
trial_params: Dict[str, Union[str, float, int, bool]] = None
stateful_callbacks: List["TrainerCallback"] = None
def __post_init__(self):
if self.log_history is None:
self.log_history = []
if self.stateful_callbacks is None:
self.stateful_callbacks = {}
elif isinstance(self.stateful_callbacks, dict):
# We are loading the callbacks in from the state file, no need to process them
pass
else:
# Saveable callbacks get stored as dict of kwargs
stateful_callbacks = {}
for callback in self.stateful_callbacks:
if not isinstance(callback, (ExportableState)):
raise TypeError(
f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
)
name = callback.__class__.__name__
if name in stateful_callbacks:
# We can have multiple versions of the same callback
# if so, we store them as a list of states to restore
if not isinstance(stateful_callbacks[name], list):
stateful_callbacks[name] = [stateful_callbacks[name]]
stateful_callbacks[name].append(callback.state())
else:
stateful_callbacks[name] = callback.state()
self.stateful_callbacks = stateful_callbacks
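To make the duplicate handling above concrete: two callbacks of the same class collapse into a single key whose value is a list of their states. A minimal sketch using a hypothetical `MyCallback` (not part of this diff):

```python
from transformers import TrainerCallback, TrainerState
from transformers.trainer_callback import ExportableState


class MyCallback(TrainerCallback, ExportableState):
    # Hypothetical callback, used only to illustrate the serialization above.
    def __init__(self, value="a"):
        self.value = value

    def state(self) -> dict:
        return {"args": {"value": self.value}, "attributes": {}}


state = TrainerState(stateful_callbacks=[MyCallback("first"), MyCallback("second")])
# After __post_init__, both states live under one key as a list:
# state.stateful_callbacks == {
#     "MyCallback": [
#         {"args": {"value": "first"}, "attributes": {}},
#         {"args": {"value": "second"}, "attributes": {}},
#     ]
# }
```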
def save_to_json(self, json_path: str):
"""Save the content of this instance in JSON format inside `json_path`."""
@@ -123,8 +150,52 @@ class TrainerState:
return cls(**json.loads(text))
class ExportableState:
"""
A class for objects whose state can be saved during `Trainer._save_checkpoint` and
loaded back in during `Trainer._load_from_checkpoint`.
These must implement a `state` function that gets called during the respective
`Trainer` method. It should only include the parameters and attributes needed to
recreate the state at a particular point in time, so that pickle is avoided and
checkpoint writing stays with standard file I/O.
Example:
```python
class EarlyStoppingCallback(TrainerCallback, ExportableState):
def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
self.early_stopping_patience = early_stopping_patience
self.early_stopping_threshold = early_stopping_threshold
# early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
self.early_stopping_patience_counter = 0
def state(self) -> dict:
return {
"args": {
"early_stopping_patience": self.early_stopping_patience,
"early_stopping_threshold": self.early_stopping_threshold,
},
"attributes": {
"early_stopping_patience_counter": self.early_stopping_patience_counter,
}
}
```"""
def state(self) -> dict:
raise NotImplementedError("You must implement a `state` function to utilize this class.")
@classmethod
def from_state(cls, state):
instance = cls(**state["args"])
for k, v in state["attributes"].items():
setattr(instance, k, v)
return instance
@dataclass
class TrainerControl:
class TrainerControl(ExportableState):
"""
A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
switches in the training loop.
@@ -172,6 +243,18 @@ class TrainerControl:
self.should_evaluate = False
self.should_log = False
def state(self) -> dict:
return {
"args": {
"should_training_stop": self.should_training_stop,
"should_epoch_stop": self.should_epoch_stop,
"should_save": self.should_save,
"should_evaluate": self.should_evaluate,
"should_log": self.should_log,
},
"attributes": {},
}
class TrainerCallback:
# no-format
@@ -546,7 +629,7 @@ class PrinterCallback(TrainerCallback):
print(logs)
class EarlyStoppingCallback(TrainerCallback):
class EarlyStoppingCallback(TrainerCallback, ExportableState):
"""
A [`TrainerCallback`] that handles early stopping.
@@ -605,3 +688,14 @@ class EarlyStoppingCallback(TrainerCallback):
self.check_metric_value(args, state, control, metric_value)
if self.early_stopping_patience_counter >= self.early_stopping_patience:
control.should_training_stop = True
def state(self) -> dict:
return {
"args": {
"early_stopping_patience": self.early_stopping_patience,
"early_stopping_threshold": self.early_stopping_threshold,
},
"attributes": {
"early_stopping_patience_counter": self.early_stopping_patience_counter,
},
}

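To see the new `EarlyStoppingCallback` state in action outside the `Trainer`, a quick round trip through `state()`/`from_state()` looks roughly like this (illustrative, not part of the diff):

```python
from transformers import EarlyStoppingCallback

cb = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2)
cb.early_stopping_patience_counter = 3  # pretend three evaluations failed to improve

# `from_state` is inherited from ExportableState: rebuild from `args`, then set `attributes`.
restored = EarlyStoppingCallback.from_state(cb.state())
assert restored.early_stopping_patience == 5
assert restored.early_stopping_threshold == 0.2
assert restored.early_stopping_patience_counter == 3
```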
src/transformers/training_args.py

@@ -357,6 +357,9 @@ class TrainingArguments:
Note that when this is true, you won't be able to resume training from checkpoint.
This enables you to save storage by not storing the optimizer, scheduler & rng state.
You can only load the model using `from_pretrained` with this option set to `True`.
restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
Whether to restore the callback states from the checkpoint. If `True`, will override
callbacks passed to the `Trainer` if they exist in the checkpoint.
use_cpu (`bool`, *optional*, defaults to `False`):
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
seed (`int`, *optional*, defaults to 42):
@@ -951,6 +954,12 @@ class TrainingArguments:
)
},
)
restore_callback_states_from_checkpoint: bool = field(
default=False,
metadata={
"help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
},
)
no_cuda: bool = field(
default=False,
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},

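Putting the new flag together with the `Trainer` changes above, resuming a run with callback state restoration would look roughly like this; a sketch in which `model`, the datasets, and the checkpoint path are placeholders:

```python
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=50,
    save_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    restore_callback_states_from_checkpoint=True,  # the flag added above
)
trainer = Trainer(
    model=model,                  # placeholder: your model
    args=args,
    train_dataset=train_dataset,  # placeholder: your datasets
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
# Callback state (e.g. the early-stopping patience counter) is read back from the
# checkpoint instead of starting fresh from the constructor arguments above.
trainer.train(resume_from_checkpoint="out/checkpoint-50")
```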
tests/trainer/test_trainer_callback.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest
@@ -19,28 +21,44 @@ from unittest.mock import patch
from transformers import (
DefaultFlowCallback,
EarlyStoppingCallback,
IntervalStrategy,
PrinterCallback,
ProgressCallback,
Trainer,
TrainerCallback,
TrainerState,
TrainingArguments,
is_torch_available,
)
from transformers.testing_utils import require_torch
from transformers.trainer_callback import ExportableState
if is_torch_available():
from transformers.trainer import DEFAULT_CALLBACKS
from transformers.trainer import DEFAULT_CALLBACKS, TRAINER_STATE_NAME
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
class MyTestExportableCallback(TrainerCallback, ExportableState):
def __init__(self, my_test_state="test"):
self.my_test_state = my_test_state
def state(self):
return {
"args": {
"my_test_state": self.my_test_state,
},
}
class MyTestTrainerCallback(TrainerCallback):
"A callback that registers the events that goes through."
def __init__(self):
def __init__(self, my_test_state="test"):
self.events = []
self.my_test_state = my_test_state
def on_init_end(self, args, state, control, **kwargs):
self.events.append("on_init_end")
@@ -243,3 +261,160 @@ class TrainerCallbackTest(unittest.TestCase):
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
)
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
def test_stateful_callbacks(self):
# Use something with non-defaults
cb = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2)
trainer = self.get_trainer(
callbacks=[cb],
load_best_model_at_end=True,
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
)
trainer.train()
# Create a new trainer with defaults
trainer = self.get_trainer(
callbacks=[EarlyStoppingCallback()],
load_best_model_at_end=True,
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
restore_callback_states_from_checkpoint=True,
)
# Load it back in and verify values
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
trainer.train(resume_from_checkpoint=checkpoint)
cb = [
callback for callback in trainer.callback_handler.callbacks if isinstance(callback, EarlyStoppingCallback)
][0]
assert cb.early_stopping_patience == 5
assert cb.early_stopping_threshold == 0.2
def test_stateful_mixed_callbacks(self):
# Use two callbacks, one stateful one not
# Use something with non-defaults
cbs = [
MyTestTrainerCallback(my_test_state="another value"),
EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2),
]
trainer = self.get_trainer(
callbacks=cbs,
load_best_model_at_end=True,
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
)
trainer.train()
# Create a new trainer with defaults
trainer = self.get_trainer(
callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()],
load_best_model_at_end=True,
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
restore_callback_states_from_checkpoint=True,
)
# Load it back in and verify values
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
trainer.train(resume_from_checkpoint=checkpoint)
cbs = [
callback
for callback in trainer.callback_handler.callbacks
if isinstance(callback, (EarlyStoppingCallback, MyTestTrainerCallback))
]
assert len(cbs) == 2
my_test, early_stopping = cbs
assert early_stopping.early_stopping_patience == 5
assert early_stopping.early_stopping_threshold == 0.2
assert my_test.my_test_state == "test"
def test_stateful_duplicate_callbacks(self):
# Use something with non-defaults
cbs = [MyTestExportableCallback("first"), MyTestExportableCallback("second")]
trainer = self.get_trainer(
callbacks=cbs,
load_best_model_at_end=True,
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
)
trainer.train()
# Create a new trainer with defaults
trainer = self.get_trainer(
callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
load_best_model_at_end=True,
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
restore_callback_states_from_checkpoint=True,
)
# Load it back in and verify values
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
trainer.train(resume_from_checkpoint=checkpoint)
cbs = [
callback
for callback in trainer.callback_handler.callbacks
if isinstance(callback, MyTestExportableCallback)
]
assert len(cbs) == 2
assert cbs[0].my_test_state == "first"
assert cbs[1].my_test_state == "second"
def test_missing_stateful_callback(self):
cb = EarlyStoppingCallback()
trainer = self.get_trainer(
callbacks=[cb],
load_best_model_at_end=True,
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
)
trainer.train()
# Create a new trainer with defaults
trainer = self.get_trainer(
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
restore_callback_states_from_checkpoint=True,
)
# Load it back in and verify values
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
# warning should be emitted for not-present callbacks
with patch("transformers.trainer.logger.warning") as warn_mock:
trainer.train(resume_from_checkpoint=checkpoint)
assert "EarlyStoppingCallback" in warn_mock.call_args[0][0]
def test_stateful_control(self):
trainer = self.get_trainer(
max_steps=2,
save_strategy="steps",
save_steps=2,
)
trainer.train()
# Load it back in and verify values
trainer = self.get_trainer(max_steps=2, restore_callback_states_from_checkpoint=True)
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
trainer._load_callback_state()
assert trainer.control.should_training_stop