4.8 KiB
Callbacks
Callbacks可以用来自定义PyTorch [Trainer]中训练循环行为的对象(此功能尚未在TensorFlow中实现),该对象可以检查训练循环状态(用于进度报告、在TensorBoard或其他ML平台上记录日志等),并做出决策(例如提前停止)。
Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]对象外,它们不能更改训练循环中的任何内容。对于需要更改训练循环的自定义,您应该继承[Trainer]并重载您需要的方法(有关示例,请参见trainer)。
默认情况下,TrainingArguments.report_to
设置为"all",然后[Trainer]将使用以下callbacks。
- [
DefaultFlowCallback
],它处理默认的日志记录、保存和评估行为 - [
PrinterCallback
] 或 [ProgressCallback
],用于显示进度和打印日志(如果通过[TrainingArguments
]停用tqdm,则使用第一个函数;否则使用第二个)。 - [
~integrations.TensorBoardCallback
],如果TensorBoard可访问(通过PyTorch版本 >= 1.4 或者 tensorboardX)。 - [
~integrations.WandbCallback
],如果安装了wandb。 - [
~integrations.CometCallback
],如果安装了comet_ml。 - [
~integrations.MLflowCallback
],如果安装了mlflow。 - [
~integrations.NeptuneCallback
],如果安装了neptune。 - [
~integrations.AzureMLCallback
],如果安装了azureml-sdk。 - [
~integrations.CodeCarbonCallback
],如果安装了codecarbon。 - [
~integrations.ClearMLCallback
],如果安装了clearml。 - [
~integrations.DagsHubCallback
],如果安装了dagshub。 - [
~integrations.FlyteCallback
],如果安装了flyte。 - [
~integrations.DVCLiveCallback
],如果安装了dvclive。
如果安装了一个软件包,但您不希望使用相关的集成,您可以将 TrainingArguments.report_to
更改为仅包含您想要使用的集成的列表(例如 ["azure_ml", "wandb"]
)。
实现callbacks的主要类是[TrainerCallback
]。它获取用于实例化[Trainer
]的[TrainingArguments
],可以通过[TrainerState
]访问该Trainer的内部状态,并可以通过[TrainerControl
]对训练循环执行一些操作。
可用的Callbacks
这里是库里可用[TrainerCallback
]的列表:
autodoc integrations.CometCallback - setup
autodoc DefaultFlowCallback
autodoc PrinterCallback
autodoc ProgressCallback
autodoc EarlyStoppingCallback
autodoc integrations.TensorBoardCallback
autodoc integrations.WandbCallback - setup
autodoc integrations.MLflowCallback - setup
autodoc integrations.AzureMLCallback
autodoc integrations.CodeCarbonCallback
autodoc integrations.NeptuneCallback
autodoc integrations.ClearMLCallback
autodoc integrations.DagsHubCallback
autodoc integrations.FlyteCallback
autodoc integrations.DVCLiveCallback - setup
TrainerCallback
autodoc TrainerCallback
以下是如何使用PyTorch注册自定义callback的示例:
[Trainer
]:
class MyCallback(TrainerCallback):
"A callback that prints a message at the beginning of training"
def on_train_begin(self, args, state, control, **kwargs):
print("Starting training")
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[MyCallback], # We can either pass the callback class this way or an instance of it (MyCallback())
)
注册callback的另一种方式是调用 trainer.add_callback()
,如下所示:
trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())
TrainerState
autodoc TrainerState
TrainerControl
autodoc TrainerControl