fix documentation for CustomTrainer (#25635)

fix doc
This commit is contained in:
mchau 2023-08-21 22:23:17 +07:00 committed by GitHub
parent 8608bf2049
commit 6f041fcbb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -60,7 +60,7 @@ from transformers import Trainer
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.get("labels")
labels = inputs.pop("labels")
# forward pass
outputs = model(**inputs)
logits = outputs.get("logits")