GPT-1 for sequence classification (#7683)

* Add Documentation for GPT-1 Classification

* Add GPT-1 with Classification head

* Add tests for GPT-1 Classification

* Add GPT-1 For Classification to auto models

* Remove authorized missing keys, change checkpoint to openai-gpt
Felipe Curti 2020-10-13 06:06:15 -03:00 committed by GitHub
parent f34b4cd1bd
commit dcba9ee03b
6 changed files with 153 additions and 4 deletions

docs/source/model_doc/gpt.rst

@@ -104,6 +104,13 @@ OpenAIGPTDoubleHeadsModel
    :members: forward

OpenAIGPTForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.OpenAIGPTForSequenceClassification
    :members: forward
TFOpenAIGPTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
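For reference, the newly documented head can be used like any other sequence classification model. A minimal usage sketch (hedged: the stock `openai-gpt` checkpoint ships no fine-tuned classification head, so the `score` layer below is randomly initialized and the logits are untrained):

    from transformers import OpenAIGPTForSequenceClassification, OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    model = OpenAIGPTForSequenceClassification.from_pretrained("openai-gpt", num_labels=2)

    # A single sequence works even though openai-gpt defines no pad_token_id,
    # because the head explicitly allows batch_size == 1 without padding.
    inputs = tokenizer("a very enjoyable film", return_tensors="pt")
    outputs = model(**inputs, return_dict=True)
    print(outputs.logits.shape)  # torch.Size([1, 2])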

src/transformers/__init__.py

@@ -437,6 +437,7 @@ if is_torch_available():
    from .modeling_openai import (
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        OpenAIGPTDoubleHeadsModel,
        OpenAIGPTForSequenceClassification,
        OpenAIGPTLMHeadModel,
        OpenAIGPTModel,
        OpenAIGPTPreTrainedModel,

src/transformers/modeling_auto.py

@@ -153,7 +153,7 @@ from .modeling_mobilebert import (
    MobileBertForTokenClassification,
    MobileBertModel,
)
-from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
+from .modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_rag import (  # noqa: F401 - need to import all RagModels to be in globals() function
    RagModel,
@@ -381,6 +381,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
        (FunnelConfig, FunnelForSequenceClassification),
        (DebertaConfig, DebertaForSequenceClassification),
        (GPT2Config, GPT2ForSequenceClassification),
        (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
    ]
)
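With this mapping entry in place, the auto classes can resolve an OpenAI GPT checkpoint to the new head. A small sketch of what the change enables (same caveat as above: the head weights start untrained):

    from transformers import AutoModelForSequenceClassification

    # The openai-gpt config now maps to OpenAIGPTForSequenceClassification
    # via MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
    model = AutoModelForSequenceClassification.from_pretrained("openai-gpt", num_labels=2)
    print(type(model).__name__)  # OpenAIGPTForSequenceClassification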

src/transformers/modeling_openai.py

@@ -25,7 +25,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu_new, swish
from .configuration_openai import OpenAIGPTConfig
@@ -36,7 +36,7 @@ from .file_utils import (
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
-from .modeling_outputs import BaseModelOutput, CausalLMOutput
+from .modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from .modeling_utils import (
    Conv1D,
    PreTrainedModel,
@@ -732,3 +732,113 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).

    :class:`~transformers.OpenAIGPTForSequenceClassification` uses the last token in order to do the classification,
    as other causal models (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it
    cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the
    same (takes the last value in each row of the batch).
    """,
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = OpenAIGPTModel(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        self.init_weights()

    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="openai-gpt",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in
            :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed
            (Mean-Square loss); if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        # Classification scores for every position; the relevant position is pooled below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-padding token in each row.
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds`."
                )

        pooled_logits = logits[range(batch_size), sequence_lengths]

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
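To see the last-token pooling in isolation, here is a toy sketch of the indexing performed in `forward` above (the tensors and `pad_token_id = 0` are illustrative assumptions, not values from the model):

    import torch

    # Two right-padded rows; pad_token_id = 0 is assumed for the example.
    input_ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 3, 4, 0]])
    sequence_lengths = torch.ne(input_ids, 0).sum(-1) - 1  # tensor([2, 3])

    # Given per-position logits of shape (batch, seq_len, num_labels),
    # one position per row is selected as the pooled classification logit.
    logits = torch.randn(2, 5, 3)
    pooled_logits = logits[range(2), sequence_lengths]  # shape (2, 3)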

src/transformers/utils/dummy_pt_objects.py

@@ -1256,6 +1256,15 @@ class OpenAIGPTDoubleHeadsModel:
        requires_pytorch(self)


class OpenAIGPTForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class OpenAIGPTLMHeadModel:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

tests/test_modeling_openai.py

@@ -30,6 +30,7 @@ if is_torch_available():
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        OpenAIGPTConfig,
        OpenAIGPTDoubleHeadsModel,
        OpenAIGPTForSequenceClassification,
        OpenAIGPTLMHeadModel,
        OpenAIGPTModel,
    )
@@ -61,6 +62,7 @@ class OpenAIGPTModelTester:
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None
        self.pad_token_id = self.vocab_size - 1

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -90,6 +92,7 @@ class OpenAIGPTModelTester:
            n_ctx=self.max_position_embeddings,
            # type_vocab_size=self.type_vocab_size,
            # initializer_range=self.initializer_range
            pad_token_id=self.pad_token_id,
            return_dict=True,
        )
@@ -134,6 +137,18 @@ class OpenAIGPTModelTester:
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_openai_gpt_for_sequence_classification(
        self, config, input_ids, head_mask, token_type_ids, *args
    ):
        config.num_labels = self.num_labels
        model = OpenAIGPTForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
        result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
@@ -158,7 +173,9 @@ class OpenAIGPTModelTester:
class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
-        (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
+        (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)
+        if is_torch_available()
+        else ()
    )
    all_generative_model_classes = (
        (OpenAIGPTLMHeadModel,) if is_torch_available() else ()
@@ -183,6 +200,10 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)

    def test_openai_gpt_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
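Outside the test harness, the new test's shape check can be reproduced with a tiny random model. A hedged sketch (the small hyperparameters below are illustrative, not the tester's values):

    import torch
    from transformers import OpenAIGPTConfig, OpenAIGPTForSequenceClassification

    # pad_token_id is set so the head can pool batched, padded inputs.
    config = OpenAIGPTConfig(
        vocab_size=99, n_positions=32, n_ctx=32, n_embd=32, n_layer=2, n_head=4,
        num_labels=3, pad_token_id=98,
    )
    model = OpenAIGPTForSequenceClassification(config).eval()

    input_ids = torch.randint(0, 98, (2, 10))  # batch of 2, no pad tokens present
    labels = torch.tensor([0, 2])
    result = model(input_ids, labels=labels, return_dict=True)
    assert result.logits.shape == (2, 3)  # (batch_size, num_labels)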