# Trainer

The [`Trainer`] class provides an API for feature-complete training in PyTorch, and it supports distributed training on multiple GPUs/TPUs, mixed precision for NVIDIA GPUs, AMD GPUs, and `torch.amp` for PyTorch. [`Trainer`] goes hand-in-hand with the [`TrainingArguments`] class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.

[`Seq2SeqTrainer`] and [`Seq2SeqTrainingArguments`] inherit from the [`Trainer`] and [`TrainingArguments`] classes and they're adapted for training models for sequence-to-sequence tasks such as summarization or translation.

The [`Trainer`] class is optimized for 🤗 Transformers models and can have surprising behaviors when used with other models. When using it with your own model, make sure:

- your model always returns tuples or subclasses of [`~utils.ModelOutput`]
- your model can compute the loss if a `labels` argument is provided, and that the loss is returned as the first element of the tuple (if your model returns tuples)
- your model can accept multiple label arguments (use `label_names` in [`TrainingArguments`] to indicate their name to the [`Trainer`]), but none of them should be named `"label"`
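The requirements above can be sketched with a toy custom model; this is a hypothetical example, not a class from the library, and assumes PyTorch is installed:

```python
import torch
from torch import nn


class RegressionModel(nn.Module):
    """Toy model illustrating the Trainer contract for custom models."""

    def __init__(self, input_dim: int = 4):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, input_values, labels=None):
        logits = self.linear(input_values).squeeze(-1)
        if labels is not None:
            # When `labels` is passed, compute the loss and return it as the
            # FIRST element of the tuple -- this is what Trainer expects.
            loss = nn.functional.mse_loss(logits, labels)
            return (loss, logits)
        # Without labels, still return a tuple.
        return (logits,)
```

Returning a [`~utils.ModelOutput`] subclass with a `loss` field instead of a plain tuple satisfies the same contract.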

## Trainer[[api-reference]]

[[autodoc]] Trainer
    - all

## Seq2SeqTrainer

[[autodoc]] Seq2SeqTrainer
    - evaluate
    - predict

## TrainingArguments

[[autodoc]] TrainingArguments
    - all

## Seq2SeqTrainingArguments

[[autodoc]] Seq2SeqTrainingArguments
    - all