ElectraForQuestionAnswering (#4913)
* ElectraForQuestionAnswering * udate __init__ * add test for electra qa model * add ElectraForQuestionAnswering in auto models * add ElectraForQuestionAnswering in all_model_classes * fix outputs, input_ids defaults to None * add ElectraForQuestionAnswering in docs * remove commented line
This commit is contained in:
parent
5d63ca6c38
commit
ef2dcdccaa
|
@ -106,6 +106,13 @@ ElectraForTokenClassification
|
|||
:members:
|
||||
|
||||
|
||||
ElectraForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.ElectraForQuestionAnswering
|
||||
:members:
|
||||
|
||||
|
||||
TFElectraModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -319,6 +319,7 @@ if is_torch_available():
|
|||
ElectraForTokenClassification,
|
||||
ElectraPreTrainedModel,
|
||||
ElectraForSequenceClassification,
|
||||
ElectraForQuestionAnswering,
|
||||
ElectraModel,
|
||||
load_tf_weights_in_electra,
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
|
|
|
@ -78,6 +78,7 @@ from .modeling_distilbert import (
|
|||
from .modeling_electra import (
|
||||
ElectraForMaskedLM,
|
||||
ElectraForPreTraining,
|
||||
ElectraForQuestionAnswering,
|
||||
ElectraForSequenceClassification,
|
||||
ElectraForTokenClassification,
|
||||
ElectraModel,
|
||||
|
@ -237,6 +238,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
|||
(XLNetConfig, XLNetForQuestionAnsweringSimple),
|
||||
(FlaubertConfig, FlaubertForQuestionAnsweringSimple),
|
||||
(XLMConfig, XLMForQuestionAnsweringSimple),
|
||||
(ElectraConfig, ElectraForQuestionAnswering),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -742,3 +742,119 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
|
|||
output += discriminator_hidden_states[1:]
|
||||
|
||||
return output # (loss), scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
ELECTRA_INPUTS_DOCSTRING,
|
||||
)
|
||||
class ElectraForQuestionAnswering(ElectraPreTrainedModel):
|
||||
config_class = ElectraConfig
|
||||
base_model_prefix = "electra"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.electra = ElectraModel(config)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
output_attentions=None,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-start scores (before SoftMax).
|
||||
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-end scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import ElectraTokenizer, ElectraForQuestionAnswering
|
||||
import torch
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')
|
||||
model = ElectraForQuestionAnswering.from_pretrained('google/electra-base-discriminator')
|
||||
|
||||
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||
input_ids = tokenizer.encode(question, text)
|
||||
start_scores, end_scores = model(torch.tensor([input_ids]))
|
||||
|
||||
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
|
||||
|
||||
"""
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
sequence_output = discriminator_hidden_states[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
outputs = (start_logits, end_logits,) + discriminator_hidden_states[1:]
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
|
|
@ -31,6 +31,7 @@ if is_torch_available():
|
|||
ElectraForTokenClassification,
|
||||
ElectraForPreTraining,
|
||||
ElectraForSequenceClassification,
|
||||
ElectraForQuestionAnswering,
|
||||
)
|
||||
from transformers.modeling_electra import ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
@ -45,6 +46,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
ElectraForMaskedLM,
|
||||
ElectraForTokenClassification,
|
||||
ElectraForSequenceClassification,
|
||||
ElectraForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
|
@ -276,6 +278,36 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_electra_for_question_answering(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
):
|
||||
model = ElectraForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
|
@ -318,6 +350,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
|
Loading…
Reference in New Issue