parent
8608bf2049
commit
6f041fcbb8
|
@ -60,7 +60,7 @@ from transformers import Trainer
|
|||
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.get("labels")
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits")
|
||||
|
|
Loading…
Reference in New Issue