# Trainer
The [`Trainer`] class provides an API for feature-complete training in PyTorch, and it supports distributed training on multiple GPUs/TPUs, mixed precision for NVIDIA GPUs, AMD GPUs, and `torch.amp` for PyTorch. [`Trainer`] goes hand in hand with the [`TrainingArguments`] class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.
[`Seq2SeqTrainer`] and [`Seq2SeqTrainingArguments`] inherit from the [`Trainer`] and [`TrainingArguments`] classes, and they're adapted for training models for sequence-to-sequence tasks such as summarization or translation.
The [`Trainer`] class is optimized for 🤗 Transformers models and can have surprising behaviors
when used with other models. When using it with your own model, make sure:

- your model always returns tuples or subclasses of [`~utils.ModelOutput`]
- your model can compute the loss if a `labels` argument is provided, and that loss is returned as the first element of the tuple (if your model returns tuples)
- your model can accept multiple label arguments (use `label_names` in [`TrainingArguments`] to indicate their names to the [`Trainer`]), but none of them should be named `"label"`
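A minimal sketch of these requirements (assuming `torch` is installed; the model, feature sizes, and argument names are illustrative): a custom module that returns a tuple with the loss as its first element whenever `labels` is passed.

```python
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self, num_labels=3):
        super().__init__()
        self.classifier = nn.Linear(8, num_labels)

    def forward(self, inputs, labels=None):
        logits = self.classifier(inputs)
        if labels is not None:
            # The loss must come first in the returned tuple
            # so Trainer can pick it up.
            loss = nn.functional.cross_entropy(logits, labels)
            return (loss, logits)
        return (logits,)

model = MyModel()
out = model(torch.randn(2, 8), labels=torch.tensor([0, 2]))
```

Without `labels`, the model returns `(logits,)` only; with `labels`, the scalar loss leads the tuple, which is exactly what `Trainer` expects from tuple-returning models.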
## Trainer[[api-reference]]

[[autodoc]] Trainer
    - all
## Seq2SeqTrainer

[[autodoc]] Seq2SeqTrainer
    - evaluate
    - predict
## TrainingArguments

[[autodoc]] TrainingArguments
    - all
## Seq2SeqTrainingArguments

[[autodoc]] Seq2SeqTrainingArguments
    - all