From 9de9ceb6c5fd53cb600db7539415baa045916e0f Mon Sep 17 00:00:00 2001 From: Anirudh Srinivasan Date: Thu, 2 Apr 2020 00:34:38 +0530 Subject: [PATCH] Correct output shape for Bert NSP models in docs (#3482) --- src/transformers/modeling_bert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 47d7b2301f..a7bc4df0de 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -845,7 +845,7 @@ class BertForPreTraining(BertPreTrainedModel): Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): @@ -1048,7 +1048,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided): Next sequence prediction (classification) loss. - seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`): + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)