added functionality for electra classification head (#4257)

* added functionality for electra classification head

* unneeded dropout

* Test ELECTRA for sequence classification

* Style

Co-authored-by: Frankie <frankie@frase.io>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Frankie Liuzzi 2020-05-22 09:48:21 -04:00 committed by GitHub
parent a086527727
commit bd6e301832
4 changed files with 140 additions and 0 deletions

src/transformers/__init__.py

@@ -321,6 +321,7 @@ if is_torch_available():
ElectraForMaskedLM,
ElectraForTokenClassification,
ElectraPreTrainedModel,
ElectraForSequenceClassification,
ElectraModel,
load_tf_weights_in_electra,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
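With the class added to the package's torch exports, it becomes importable from the top level. A minimal sketch (assuming a torch install, so the is_torch_available() branch runs):

from transformers import ElectraConfig, ElectraForSequenceClassification

# Build from a bare config; no pretrained weights are downloaded here.
config = ElectraConfig(num_labels=3)
model = ElectraForSequenceClassification(config)
print(model.num_labels)  # 3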

src/transformers/modeling_auto.py

@@ -88,6 +88,7 @@ from .modeling_electra import (
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
ElectraForMaskedLM,
ElectraForPreTraining,
ElectraForSequenceClassification,
ElectraForTokenClassification,
ElectraModel,
)
@@ -251,6 +252,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(XLNetConfig, XLNetForSequenceClassification),
(FlaubertConfig, FlaubertForSequenceClassification),
(XLMConfig, XLMForSequenceClassification),
(ElectraConfig, ElectraForSequenceClassification),
]
)
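Registering the (ElectraConfig, ElectraForSequenceClassification) pair lets the auto classes dispatch ELECTRA checkpoints to the new head. A sketch, assuming the public google/electra-small-discriminator checkpoint is reachable:

from transformers import AutoConfig, AutoModelForSequenceClassification

# The checkpoint's config class (ElectraConfig) is the lookup key in
# MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, so the ELECTRA head is selected.
config = AutoConfig.from_pretrained("google/electra-small-discriminator", num_labels=2)
model = AutoModelForSequenceClassification.from_pretrained(
    "google/electra-small-discriminator", config=config
)
print(type(model).__name__)  # ElectraForSequenceClassification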

src/transformers/modeling_electra.py

@@ -3,6 +3,7 @@ import os
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import get_activation
from .configuration_electra import ElectraConfig
@@ -330,6 +331,112 @@ class ElectraModel(ElectraPreTrainedModel):
return hidden_states
class ElectraClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[:, 0, :]  # take the representation of the first token ([CLS])
x = self.dropout(x)
x = self.dense(x)
x = get_activation("gelu")(x)  # BERT uses tanh here, but the ELECTRA authors appear to have used gelu
x = self.dropout(x)
x = self.out_proj(x)
return x
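The head takes the full (batch, seq_len, hidden_size) tensor and keeps only the first token before projecting to num_labels. A shape-check sketch, assuming the class remains importable from transformers.modeling_electra as laid out in this diff:

import torch
from transformers import ElectraConfig
from transformers.modeling_electra import ElectraClassificationHead

config = ElectraConfig(num_labels=4)
head = ElectraClassificationHead(config)
hidden = torch.rand(2, 16, config.hidden_size)  # (batch, seq_len, hidden_size)
logits = head(hidden)
print(logits.shape)  # torch.Size([2, 4])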
@add_start_docstrings(
"""ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
ELECTRA_START_DOCSTRING,
)
class ElectraForSequenceClassification(ElectraPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.electra = ElectraModel(config)
self.classifier = ElectraClassificationHead(config)
self.init_weights()
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import ElectraTokenizer, ElectraForSequenceClassification
import torch
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
model = ElectraForSequenceClassification.from_pretrained('google/electra-small-discriminator')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds
)
sequence_output = discriminator_hidden_states[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + discriminator_hidden_states[1:]  # add hidden states and attentions if they are here (ElectraModel returns no pooled output)
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
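The loss branch mirrors BertForSequenceClassification: MSE for a single label, cross-entropy otherwise. A quick sketch of both paths with randomly initialized weights, so only the shapes and the chosen loss matter:

import torch
from transformers import ElectraConfig, ElectraForSequenceClassification

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])  # toy ids within the default vocab

# num_labels > 1 -> CrossEntropyLoss over the class logits
clf = ElectraForSequenceClassification(ElectraConfig(num_labels=3))
loss, logits = clf(input_ids, labels=torch.tensor([2]))[:2]    # logits: (1, 3)

# num_labels == 1 -> MSELoss, i.e. a regression head
reg = ElectraForSequenceClassification(ElectraConfig(num_labels=1))
loss, scores = reg(input_ids, labels=torch.tensor([0.5]))[:2]  # scores: (1, 1)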
@add_start_docstrings(
"""
Electra model with a binary classification head on top as used during pre-training for identifying generated

tests/test_modeling_electra.py

@@ -30,6 +30,7 @@ if is_torch_available():
ElectraForMaskedLM,
ElectraForTokenClassification,
ElectraForPreTraining,
ElectraForSequenceClassification,
)
from transformers.modeling_electra import ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -242,6 +243,31 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_electra_for_sequence_classification(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
fake_token_labels,
):
config.num_labels = self.num_labels
model = ElectraForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -280,6 +306,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs)
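To run just the new test locally, something along these lines should work (the tests/test_modeling_electra.py path and the -k filter are assumptions about the local checkout, not part of this diff):

import pytest

# Select only the sequence-classification test added above.
pytest.main(["tests/test_modeling_electra.py", "-k", "sequence_classification", "-q"])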
@slow
def test_model_from_pretrained(self):
for model_name in list(ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: