From ef2dcdccaa9a115aca44d81f31c6dc4d32bebb3f Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Thu, 11 Jun 2020 00:47:52 +0530
Subject: [PATCH] ElectraForQuestionAnswering (#4913)

* ElectraForQuestionAnswering

* udate __init__

* add test for electra qa model

* add ElectraForQuestionAnswering in auto models

* add ElectraForQuestionAnswering in all_model_classes

* fix outputs, input_ids defaults to None

* add ElectraForQuestionAnswering in docs

* remove commented line
---
 docs/source/model_doc/electra.rst    |   7 ++
 src/transformers/__init__.py         |   1 +
 src/transformers/modeling_auto.py    |   2 +
 src/transformers/modeling_electra.py | 116 +++++++++++++++++++++
 tests/test_modeling_electra.py       |  36 +++++++++
 5 files changed, 162 insertions(+)

diff --git a/docs/source/model_doc/electra.rst b/docs/source/model_doc/electra.rst
index 98a21ab25e..431b4f2717 100644
--- a/docs/source/model_doc/electra.rst
+++ b/docs/source/model_doc/electra.rst
@@ -106,6 +106,13 @@ ElectraForTokenClassification
     :members:
 
 
+ElectraForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ElectraForQuestionAnswering
+    :members:
+
+
 TFElectraModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ca973acec5..674ab5850d 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -319,6 +319,7 @@ if is_torch_available():
         ElectraForTokenClassification,
         ElectraPreTrainedModel,
         ElectraForSequenceClassification,
+        ElectraForQuestionAnswering,
         ElectraModel,
         load_tf_weights_in_electra,
         ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py
index ca25c5cb76..31236e1e47 100644
--- a/src/transformers/modeling_auto.py
+++ b/src/transformers/modeling_auto.py
@@ -78,6 +78,7 @@ from .modeling_distilbert import (
 from .modeling_electra import (
     ElectraForMaskedLM,
     ElectraForPreTraining,
+    ElectraForQuestionAnswering,
     ElectraForSequenceClassification,
     ElectraForTokenClassification,
     ElectraModel,
@@ -237,6 +238,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
         (XLNetConfig, XLNetForQuestionAnsweringSimple),
         (FlaubertConfig, FlaubertForQuestionAnsweringSimple),
         (XLMConfig, XLMForQuestionAnsweringSimple),
+        (ElectraConfig, ElectraForQuestionAnswering),
     ]
 )
 
diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py
index e85a57cddd..df8ef8c10e 100644
--- a/src/transformers/modeling_electra.py
+++ b/src/transformers/modeling_electra.py
@@ -742,3 +742,119 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
         output += discriminator_hidden_states[1:]
 
         return output  # (loss), scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    ELECTRA_INPUTS_DOCSTRING,
+)
+class ElectraForQuestionAnswering(ElectraPreTrainedModel):
+    config_class = ElectraConfig
+    base_model_prefix = "electra"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.electra = ElectraModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        output_attentions=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-start scores (before SoftMax).
+        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-end scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import ElectraTokenizer, ElectraForQuestionAnswering
+        import torch
+
+        tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')
+        model = ElectraForQuestionAnswering.from_pretrained('google/electra-base-discriminator')
+
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_ids = tokenizer.encode(question, text)
+        start_scores, end_scores = model(torch.tensor([input_ids]))
+
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])

+        """
+
+        discriminator_hidden_states = self.electra(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = discriminator_hidden_states[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + discriminator_hidden_states[1:]
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py
index 9c0a676d26..d09fcc4b61 100644
--- a/tests/test_modeling_electra.py
+++ b/tests/test_modeling_electra.py
@@ -31,6 +31,7 @@ if is_torch_available():
         ElectraForTokenClassification,
         ElectraForPreTraining,
         ElectraForSequenceClassification,
+        ElectraForQuestionAnswering,
     )
     from transformers.modeling_electra import ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST
 
@@ -45,6 +46,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
             ElectraForMaskedLM,
             ElectraForTokenClassification,
             ElectraForSequenceClassification,
+            ElectraForQuestionAnswering,
         )
         if is_torch_available()
         else ()
@@ -276,6 +278,36 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
             self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
             self.check_loss_output(result)
 
+        def create_and_check_electra_for_question_answering(
+            self,
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            fake_token_labels,
+        ):
+            model = ElectraForQuestionAnswering(config=config)
+            model.to(torch_device)
+            model.eval()
+            loss, start_logits, end_logits = model(
+                input_ids,
+                attention_mask=input_mask,
+                token_type_ids=token_type_ids,
+                start_positions=sequence_labels,
+                end_positions=sequence_labels,
+            )
+            result = {
+                "loss": loss,
+                "start_logits": start_logits,
+                "end_logits": end_logits,
+            }
+            self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -318,6 +350,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs) + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: