Merge pull request #70 from deepset-ai/fix_lm_loss
fix typo in input for masked lm loss function
This commit is contained in:
commit
8c7267f1cf
|
@ -678,7 +678,7 @@ class BertForPreTraining(PreTrainedBertModel):
|
|||
|
||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
return total_loss
|
||||
|
|
Loading…
Reference in New Issue