Correct assignement for logits in classifier example

I tried to address https://github.com/huggingface/pytorch-pretrained-BERT/issues/76
should be correct, but there's likely a more efficient way.
This commit is contained in:
Davide Fiocco 2018-12-02 12:38:26 +01:00 committed by GitHub
parent 063be09b71
commit e60e8a6068
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -605,7 +605,8 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
logits = model(input_ids, segment_ids, input_mask)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()