fix return dicitonary labels from masked_lm_labels to labels (#7595)

This commit is contained in:
George Mihaila 2020-10-06 08:12:04 -05:00 committed by GitHub
parent 8d2c248df7
commit 4d541f516f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -446,7 +446,7 @@ class DataCollatorForNextSentencePrediction:
"input_ids": input_ids,
"attention_mask": self._tensorize_batch(attention_masks),
"token_type_ids": self._tensorize_batch(segment_ids),
"masked_lm_labels": mlm_labels if self.mlm else None,
"labels": mlm_labels if self.mlm else None,
"next_sentence_label": torch.tensor(nsp_labels),
}
if self.mlm: