Gpt1 for sequence classification (#7683)
* Add Documentation for GPT-1 Classification * Add GPT-1 with Classification head * Add tests for GPT-1 Classification * Add GPT-1 For Classification to auto models * Remove authorized missing keys, change checkpoint to openai-gpt
This commit is contained in:
parent
f34b4cd1bd
commit
dcba9ee03b
|
@ -104,6 +104,13 @@ OpenAIGPTDoubleHeadsModel
|
|||
:members: forward
|
||||
|
||||
|
||||
OpenAIGPTForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.OpenAIGPTForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
TFOpenAIGPTModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -437,6 +437,7 @@ if is_torch_available():
|
|||
from .modeling_openai import (
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
OpenAIGPTDoubleHeadsModel,
|
||||
OpenAIGPTForSequenceClassification,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTModel,
|
||||
OpenAIGPTPreTrainedModel,
|
||||
|
|
|
@ -153,7 +153,7 @@ from .modeling_mobilebert import (
|
|||
MobileBertForTokenClassification,
|
||||
MobileBertModel,
|
||||
)
|
||||
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from .modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from .modeling_pegasus import PegasusForConditionalGeneration
|
||||
from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
|
||||
RagModel,
|
||||
|
@ -381,6 +381,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
|||
(FunnelConfig, FunnelForSequenceClassification),
|
||||
(DebertaConfig, DebertaForSequenceClassification),
|
||||
(GPT2Config, GPT2ForSequenceClassification),
|
||||
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from typing import Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .activations import gelu_new, swish
|
||||
from .configuration_openai import OpenAIGPTConfig
|
||||
|
@ -36,7 +36,7 @@ from .file_utils import (
|
|||
add_start_docstrings_to_callable,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_outputs import BaseModelOutput, CausalLMOutput
|
||||
from .modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from .modeling_utils import (
|
||||
Conv1D,
|
||||
PreTrainedModel,
|
||||
|
@ -732,3 +732,113 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Original OpenAI GPT Model transformer with a sequence classification head on top
|
||||
(linear layer).
|
||||
:class:`~transformers.OpenAIGPTForSequenceClassification` uses the last token in order to do the classification, as
|
||||
other causal models (e.g. GPT-2) do.
|
||||
Since it does classification on the last token, it requires to know the position of the last token.
|
||||
If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token
|
||||
in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch.
|
||||
Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it
|
||||
does the same (take the last value in each row of the batch).
|
||||
""",
|
||||
OPENAI_GPT_START_DOCSTRING,
|
||||
)
|
||||
class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="openai-gpt",
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = input_ids.shape[:2]
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
f"unexpected if using padding tokens in conjuction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -1256,6 +1256,15 @@ class OpenAIGPTDoubleHeadsModel:
|
|||
requires_pytorch(self)
|
||||
|
||||
|
||||
class OpenAIGPTForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class OpenAIGPTLMHeadModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
|
|
@ -30,6 +30,7 @@ if is_torch_available():
|
|||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
OpenAIGPTConfig,
|
||||
OpenAIGPTDoubleHeadsModel,
|
||||
OpenAIGPTForSequenceClassification,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTModel,
|
||||
)
|
||||
|
@ -61,6 +62,7 @@ class OpenAIGPTModelTester:
|
|||
self.num_labels = 3
|
||||
self.num_choices = 4
|
||||
self.scope = None
|
||||
self.pad_token_id = self.vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
@ -90,6 +92,7 @@ class OpenAIGPTModelTester:
|
|||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
pad_token_id=self.pad_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
|
@ -134,6 +137,18 @@ class OpenAIGPTModelTester:
|
|||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_openai_gpt_for_sequence_classification(
|
||||
self, config, input_ids, head_mask, token_type_ids, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = OpenAIGPTForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
# print(config.num_labels, sequence_labels.size())
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
|
@ -158,7 +173,9 @@ class OpenAIGPTModelTester:
|
|||
class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
|
||||
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (
|
||||
(OpenAIGPTLMHeadModel,) if is_torch_available() else ()
|
||||
|
@ -183,6 +200,10 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_openai_gpt_classification_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
|
Loading…
Reference in New Issue