Adding [T5/MT5/UMT5]ForTokenClassification (#28443)

* Adding [T5/MT5/UMT5]ForTokenClassification

* Add auto mappings for T5ForTokenClassification and variants

* Adding ForTokenClassification to the list of models

* Adding attention_mask param to the T5ForTokenClassification test

* Remove outdated comment in test

* Adding EncoderOnly and Token Classification tests for MT5 and UMT5

* Fix typo in umt5 string

* Add tests for all the existing MT5 models

* Fix wrong comment in dependency_versions_table

* Reverting change to common test for _keys_to_ignore_on_load_missing

The test is correctly picking up redundant keys in _keys_to_ignore_on_load_missing.

* Removing _keys_to_ignore_on_missing from MT5 since the key is not used in the model

* Add fix-copies to MT5ModelTest
This commit is contained in:
JB (Don) 2024-02-01 10:53:49 +08:00 committed by GitHub
parent 7b2bd1fbbd
commit 0d26abdd3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1579 additions and 54 deletions

View File

@ -101,6 +101,10 @@ See [`T5TokenizerFast`] for all details.
[[autodoc]] MT5ForSequenceClassification
## MT5ForTokenClassification
[[autodoc]] MT5ForTokenClassification
## MT5ForQuestionAnswering
[[autodoc]] MT5ForQuestionAnswering

View File

@ -402,6 +402,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] T5ForSequenceClassification
- forward
## T5ForTokenClassification
[[autodoc]] T5ForTokenClassification
- forward
## T5ForQuestionAnswering
[[autodoc]] T5ForQuestionAnswering

View File

@ -100,6 +100,11 @@ Refer to [T5's documentation page](t5) for more tips, code examples and notebook
[[autodoc]] UMT5ForSequenceClassification
- forward
## UMT5ForTokenClassification
[[autodoc]] UMT5ForTokenClassification
- forward
## UMT5ForQuestionAnswering
[[autodoc]] UMT5ForQuestionAnswering

View File

@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [BROS](../model_doc/bros), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Phi](../model_doc/phi), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [BROS](../model_doc/bros), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Phi](../model_doc/phi), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
<!--End of the generated tip-->

View File

@ -2731,6 +2731,7 @@ else:
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5ForTokenClassification",
"MT5Model",
"MT5PreTrainedModel",
]
@ -3299,6 +3300,7 @@ else:
"T5ForConditionalGeneration",
"T5ForQuestionAnswering",
"T5ForSequenceClassification",
"T5ForTokenClassification",
"T5Model",
"T5PreTrainedModel",
"load_tf_weights_in_t5",
@ -3370,6 +3372,7 @@ else:
"UMT5ForConditionalGeneration",
"UMT5ForQuestionAnswering",
"UMT5ForSequenceClassification",
"UMT5ForTokenClassification",
"UMT5Model",
"UMT5PreTrainedModel",
]
@ -7223,6 +7226,7 @@ if TYPE_CHECKING:
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5ForTokenClassification,
MT5Model,
MT5PreTrainedModel,
)
@ -7688,6 +7692,7 @@ if TYPE_CHECKING:
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5ForTokenClassification,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,
@ -7743,6 +7748,7 @@ if TYPE_CHECKING:
UMT5ForConditionalGeneration,
UMT5ForQuestionAnswering,
UMT5ForSequenceClassification,
UMT5ForTokenClassification,
UMT5Model,
UMT5PreTrainedModel,
)

View File

@ -950,6 +950,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mpnet", "MPNetForTokenClassification"),
("mpt", "MptForTokenClassification"),
("mra", "MraForTokenClassification"),
("mt5", "MT5ForTokenClassification"),
("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("phi", "PhiForTokenClassification"),
@ -960,6 +961,8 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("roc_bert", "RoCBertForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("t5", "T5ForTokenClassification"),
("umt5", "UMT5ForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),

View File

@ -52,6 +52,7 @@ else:
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5ForTokenClassification",
"MT5Model",
"MT5PreTrainedModel",
"MT5Stack",
@ -88,6 +89,7 @@ if TYPE_CHECKING:
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5ForTokenClassification,
MT5Model,
MT5PreTrainedModel,
MT5Stack,

View File

@ -71,6 +71,7 @@ class MT5Config(PretrainedConfig):
model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
def __init__(
self,
@ -97,15 +98,6 @@ class MT5Config(PretrainedConfig):
classifier_dropout=0.0,
**kwargs,
):
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
@ -139,17 +131,15 @@ class MT5Config(PretrainedConfig):
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"
@property
def hidden_size(self):
return self.d_model
@property
def num_attention_heads(self):
return self.num_heads
@property
def num_hidden_layers(self):
return self.num_layers
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):

View File

@ -32,6 +32,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@ -54,6 +55,19 @@ _CONFIG_FOR_DOC = "MT5Config"
_CHECKPOINT_FOR_DOC = "mt5-small"
####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################
MT5_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/mt5-small",
"google/mt5-base",
"google/mt5-large",
"google/mt5-xl",
"google/mt5-xxl",
# See all mT5 models at https://huggingface.co/models?filter=mt5
]
PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is a subject to change at a moment's notice.
@ -804,6 +818,10 @@ class MT5PreTrainedModel(PreTrainedModel):
if hasattr(module, "qa_outputs"):
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
module.qa_outputs.bias.data.zero_()
elif isinstance(module, MT5ForTokenClassification):
if hasattr(module, "classifier"):
module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
module.classifier.bias.data.zero_()
elif isinstance(module, MT5ClassificationHead):
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.dense, "bias") and module.dense.bias is not None:
@ -1334,7 +1352,6 @@ class MT5Model(MT5PreTrainedModel):
model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
@ -2158,6 +2175,80 @@ class MT5ForSequenceClassification(MT5PreTrainedModel):
)
@add_start_docstrings(
"""
MT5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
e.g. for Named-Entity-Recognition (NER) tasks.
""",
MT5_START_DOCSTRING,
)
class MT5ForTokenClassification(MT5PreTrainedModel):
_tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->MT5
def __init__(self, config: MT5Config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = MT5EncoderModel(config)
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->MT5
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits, outputs[2:-1])
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
MT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers

View File

@ -58,6 +58,7 @@ else:
"load_tf_weights_in_t5",
"T5ForQuestionAnswering",
"T5ForSequenceClassification",
"T5ForTokenClassification",
]
try:
@ -119,6 +120,7 @@ if TYPE_CHECKING:
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5ForTokenClassification,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,

View File

@ -33,6 +33,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
@ -832,6 +833,10 @@ class T5PreTrainedModel(PreTrainedModel):
if hasattr(module, "qa_outputs"):
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
module.qa_outputs.bias.data.zero_()
elif isinstance(module, T5ForTokenClassification):
if hasattr(module, "classifier"):
module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
module.classifier.bias.data.zero_()
elif isinstance(module, T5ClassificationHead):
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.dense, "bias") and module.dense.bias is not None:
@ -2118,6 +2123,78 @@ class T5ForSequenceClassification(T5PreTrainedModel):
)
@add_start_docstrings(
"""
T5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
e.g. for Named-Entity-Recognition (NER) tasks.
""",
T5_START_DOCSTRING,
)
class T5ForTokenClassification(T5PreTrainedModel):
_tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = T5EncoderModel(config)
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits, outputs[2:-1])
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers

View File

@ -31,6 +31,7 @@ else:
"UMT5ForConditionalGeneration",
"UMT5ForQuestionAnswering",
"UMT5ForSequenceClassification",
"UMT5ForTokenClassification",
"UMT5Model",
"UMT5PreTrainedModel",
]
@ -49,6 +50,7 @@ if TYPE_CHECKING:
UMT5ForConditionalGeneration,
UMT5ForQuestionAnswering,
UMT5ForSequenceClassification,
UMT5ForTokenClassification,
UMT5Model,
UMT5PreTrainedModel,
)

View File

@ -76,6 +76,7 @@ class UMT5Config(PretrainedConfig):
model_type = "umt5"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
def __init__(
self,
@ -102,15 +103,6 @@ class UMT5Config(PretrainedConfig):
classifier_dropout=0.0,
**kwargs,
):
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
@ -143,17 +135,15 @@ class UMT5Config(PretrainedConfig):
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"
@property
def hidden_size(self):
return self.d_model
@property
def num_attention_heads(self):
return self.num_heads
@property
def num_hidden_layers(self):
return self.num_layers
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
class UMT5OnnxConfig(OnnxSeq2SeqConfigWithPast):

View File

@ -30,6 +30,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
@ -515,6 +516,10 @@ class UMT5PreTrainedModel(PreTrainedModel):
if hasattr(module, "qa_outputs"):
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
module.qa_outputs.bias.data.zero_()
elif isinstance(module, UMT5ForTokenClassification):
if hasattr(module, "classifier"):
module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
module.classifier.bias.data.zero_()
elif isinstance(module, UMT5ClassificationHead):
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.dense, "bias") and module.dense.bias is not None:
@ -941,7 +946,7 @@ class UMT5Model(UMT5PreTrainedModel):
>>> hidden_states = outputs.last_hidden_state
```"""
model_type = "uumt5"
model_type = "umt5"
config_class = UMT5Config
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
@ -1589,6 +1594,81 @@ class UMT5ForSequenceClassification(UMT5PreTrainedModel):
)
@add_start_docstrings(
"""
UMT5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
e.g. for Named-Entity-Recognition (NER) tasks.
""",
UMT5_START_DOCSTRING,
)
class UMT5ForTokenClassification(UMT5PreTrainedModel):
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
_tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->UMT5
def __init__(self, config: UMT5Config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = UMT5EncoderModel(config)
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->UMT5
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits, outputs[2:-1])
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
UMT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers

View File

@ -5724,6 +5724,13 @@ class MT5ForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MT5ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MT5Model(metaclass=DummyObject):
_backends = ["torch"]
@ -7977,6 +7984,13 @@ class T5ForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"])
class T5ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class T5Model(metaclass=DummyObject):
_backends = ["torch"]
@ -8216,6 +8230,13 @@ class UMT5ForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"])
class UMT5ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class UMT5Model(metaclass=DummyObject):
_backends = ["torch"]

File diff suppressed because it is too large Load Diff

View File

@ -52,6 +52,7 @@ if is_torch_available():
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5ForTokenClassification,
T5Model,
T5Tokenizer,
)
@ -586,9 +587,11 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if tokenizer_name is None:
return True
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False
@ -998,6 +1001,22 @@ class T5EncoderOnlyModelTester:
output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_with_token_classification_head(
self,
config,
input_ids,
attention_mask,
):
labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device)
model = T5ForTokenClassification(config=config).to(torch_device).eval()
outputs = model(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels))
self.parent.assertEqual(outputs["loss"].size(), ())
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@ -1013,11 +1032,18 @@ class T5EncoderOnlyModelTester:
return config, inputs_dict
class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (T5EncoderModel,) if is_torch_available() else ()
class T5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (T5EncoderModel, T5ForTokenClassification) if is_torch_available() else ()
test_pruning = False
test_resize_embeddings = False
test_model_parallel = True
pipeline_model_mapping = (
{
"token-classification": T5ForTokenClassification,
}
if is_torch_available()
else {}
)
all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else ()
def setUp(self):
@ -1036,6 +1062,10 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
def test_with_token_classification_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])

View File

@ -18,7 +18,7 @@ import pickle
import tempfile
import unittest
from transformers import T5Config, is_torch_available
from transformers import UMT5Config, is_torch_available
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
from transformers.testing_utils import (
require_sentencepiece,
@ -30,6 +30,7 @@ from transformers.testing_utils import (
from transformers.utils import is_torch_fx_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@ -43,9 +44,11 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
UMT5EncoderModel,
UMT5ForConditionalGeneration,
UMT5ForQuestionAnswering,
UMT5ForSequenceClassification,
UMT5ForTokenClassification,
UMT5Model,
)
@ -100,7 +103,7 @@ class UMT5ModelTester:
self.decoder_layers = decoder_layers
def get_large_model_config(self):
return T5Config.from_pretrained("google/umt5-base")
return UMT5Config.from_pretrained("google/umt5-base")
def prepare_inputs_dict(
self,
@ -160,7 +163,7 @@ class UMT5ModelTester:
return config, inputs_dict
def get_pipeline_config(self):
return T5Config(
return UMT5Config(
vocab_size=166, # t5 forces 100 extra tokens
d_model=self.hidden_size,
d_ff=self.d_ff,
@ -178,7 +181,7 @@ class UMT5ModelTester:
)
def get_config(self):
return T5Config(
return UMT5Config(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
d_ff=self.d_ff,
@ -556,6 +559,176 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
pass
# Copied from tests.models.t5.test_modeling_t5.T5EncoderOnlyModelTester with T5->UMT5
class UMT5EncoderOnlyModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
encoder_seq_length=7,
# For common tests
use_attention_mask=True,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
is_training=False,
dropout_rate=0.1,
initializer_factor=0.002,
is_encoder_decoder=False,
eos_token_id=1,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
# For common tests
self.seq_length = self.encoder_seq_length
self.use_attention_mask = use_attention_mask
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.is_training = is_training
def get_large_model_config(self):
return UMT5Config.from_pretrained("t5-base")
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
config = UMT5Config(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
)
def create_and_check_model(
self,
config,
input_ids,
attention_mask,
):
model = UMT5EncoderModel(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids=input_ids,
attention_mask=attention_mask,
)
result = model(input_ids=input_ids)
encoder_output = result.last_hidden_state
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
def create_and_check_model_fp16_forward(
self,
config,
input_ids,
attention_mask,
):
model = UMT5EncoderModel(config=config).to(torch_device).half().eval()
output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_with_token_classification_head(
self,
config,
input_ids,
attention_mask,
):
labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device)
model = UMT5ForTokenClassification(config=config).to(torch_device).eval()
outputs = model(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels))
self.parent.assertEqual(outputs["loss"].size(), ())
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
# Copied from tests.models.t5.test_modeling_t5.T5EncoderOnlyModelTest with T5->UMT5
class UMT5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (UMT5EncoderModel, UMT5ForTokenClassification) if is_torch_available() else ()
test_pruning = False
test_resize_embeddings = False
test_model_parallel = True
pipeline_model_mapping = (
{
"token-classification": UMT5ForTokenClassification,
}
if is_torch_available()
else {}
)
all_parallelizable_model_classes = (UMT5EncoderModel,) if is_torch_available() else ()
def setUp(self):
self.model_tester = UMT5EncoderOnlyModelTester(self)
self.config_tester = ConfigTester(self, config_class=UMT5Config, d_model=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
def test_with_token_classification_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
@require_torch
@require_sentencepiece
@require_tokenizers