diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 8e136da37b..a5e7d2c30d 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -605,7 +605,8 @@ def main(): label_ids = label_ids.to(device) with torch.no_grad(): - tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) + tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) + logits = model(input_ids, segment_ids, input_mask) logits = logits.detach().cpu().numpy() label_ids = label_ids.to('cpu').numpy()