Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)
* Add MistralForTokenClassification * Add tests and docs * Add token classification for Mixtral and Qwen2 * Save llma for token classification draft * Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2 * Formatting * Add token classification support for Qwen2Moe model * Add dropout layer to each ForTokenClassification model * Add copied from in tests * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Propagate suggested changes * Style --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
parent
481a957814
commit
07bf2dff78
|
@ -60,6 +60,11 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [
|
|||
[[autodoc]] GemmaForSequenceClassification
|
||||
- forward
|
||||
|
||||
## GemmaForTokenClassification
|
||||
|
||||
[[autodoc]] GemmaForTokenClassification
|
||||
- forward
|
||||
|
||||
## FlaxGemmaModel
|
||||
|
||||
[[autodoc]] FlaxGemmaModel
|
||||
|
|
|
@ -121,6 +121,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||
[[autodoc]] LlamaForQuestionAnswering
|
||||
- forward
|
||||
|
||||
## LlamaForTokenClassification
|
||||
|
||||
[[autodoc]] LlamaForTokenClassification
|
||||
- forward
|
||||
|
||||
## FlaxLlamaModel
|
||||
|
||||
[[autodoc]] FlaxLlamaModel
|
||||
|
|
|
@ -203,6 +203,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||
[[autodoc]] MistralForSequenceClassification
|
||||
- forward
|
||||
|
||||
## MistralForTokenClassification
|
||||
|
||||
[[autodoc]] MistralForTokenClassification
|
||||
- forward
|
||||
|
||||
## FlaxMistralModel
|
||||
|
||||
[[autodoc]] FlaxMistralModel
|
||||
|
|
|
@ -204,3 +204,8 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||
|
||||
[[autodoc]] MixtralForSequenceClassification
|
||||
- forward
|
||||
|
||||
## MixtralForTokenClassification
|
||||
|
||||
[[autodoc]] MixtralForTokenClassification
|
||||
- forward
|
||||
|
|
|
@ -96,3 +96,8 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. T
|
|||
|
||||
[[autodoc]] PersimmonForSequenceClassification
|
||||
- forward
|
||||
|
||||
## PersimmonForTokenClassification
|
||||
|
||||
[[autodoc]] PersimmonForTokenClassification
|
||||
- forward
|
||||
|
|
|
@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Chat-beta` for the inferen
|
|||
|
||||
[[autodoc]] Qwen2ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen2ForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen2ForTokenClassification
|
||||
- forward
|
||||
|
|
|
@ -75,3 +75,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf
|
|||
|
||||
[[autodoc]] Qwen2MoeForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen2MoeForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen2MoeForTokenClassification
|
||||
- forward
|
||||
|
|
|
@ -104,3 +104,8 @@ Now, to run the model with Flash Attention 2, refer to the snippet below:
|
|||
|
||||
[[autodoc]] StableLmForSequenceClassification
|
||||
- forward
|
||||
|
||||
## StableLmForTokenClassification
|
||||
|
||||
[[autodoc]] StableLmForTokenClassification
|
||||
- forward
|
||||
|
|
|
@ -66,3 +66,8 @@ These ready-to-use checkpoints can be downloaded and used via the HuggingFace Hu
|
|||
|
||||
[[autodoc]] Starcoder2ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Starcoder2ForTokenClassification
|
||||
|
||||
[[autodoc]] Starcoder2ForTokenClassification
|
||||
- forward
|
||||
|
|
|
@ -2031,6 +2031,7 @@ else:
|
|||
[
|
||||
"GemmaForCausalLM",
|
||||
"GemmaForSequenceClassification",
|
||||
"GemmaForTokenClassification",
|
||||
"GemmaModel",
|
||||
"GemmaPreTrainedModel",
|
||||
]
|
||||
|
@ -2288,6 +2289,7 @@ else:
|
|||
"LlamaForCausalLM",
|
||||
"LlamaForQuestionAnswering",
|
||||
"LlamaForSequenceClassification",
|
||||
"LlamaForTokenClassification",
|
||||
"LlamaModel",
|
||||
"LlamaPreTrainedModel",
|
||||
]
|
||||
|
@ -2435,12 +2437,19 @@ else:
|
|||
[
|
||||
"MistralForCausalLM",
|
||||
"MistralForSequenceClassification",
|
||||
"MistralForTokenClassification",
|
||||
"MistralModel",
|
||||
"MistralPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mixtral"].extend(
|
||||
["MixtralForCausalLM", "MixtralForSequenceClassification", "MixtralModel", "MixtralPreTrainedModel"]
|
||||
[
|
||||
"MixtralForCausalLM",
|
||||
"MixtralForSequenceClassification",
|
||||
"MixtralForTokenClassification",
|
||||
"MixtralModel",
|
||||
"MixtralPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mobilebert"].extend(
|
||||
[
|
||||
|
@ -2714,6 +2723,7 @@ else:
|
|||
[
|
||||
"PersimmonForCausalLM",
|
||||
"PersimmonForSequenceClassification",
|
||||
"PersimmonForTokenClassification",
|
||||
"PersimmonModel",
|
||||
"PersimmonPreTrainedModel",
|
||||
]
|
||||
|
@ -2810,6 +2820,7 @@ else:
|
|||
[
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen2ForSequenceClassification",
|
||||
"Qwen2ForTokenClassification",
|
||||
"Qwen2Model",
|
||||
"Qwen2PreTrainedModel",
|
||||
]
|
||||
|
@ -2818,6 +2829,7 @@ else:
|
|||
[
|
||||
"Qwen2MoeForCausalLM",
|
||||
"Qwen2MoeForSequenceClassification",
|
||||
"Qwen2MoeForTokenClassification",
|
||||
"Qwen2MoeModel",
|
||||
"Qwen2MoePreTrainedModel",
|
||||
]
|
||||
|
@ -3066,6 +3078,7 @@ else:
|
|||
[
|
||||
"StableLmForCausalLM",
|
||||
"StableLmForSequenceClassification",
|
||||
"StableLmForTokenClassification",
|
||||
"StableLmModel",
|
||||
"StableLmPreTrainedModel",
|
||||
]
|
||||
|
@ -3074,6 +3087,7 @@ else:
|
|||
[
|
||||
"Starcoder2ForCausalLM",
|
||||
"Starcoder2ForSequenceClassification",
|
||||
"Starcoder2ForTokenClassification",
|
||||
"Starcoder2Model",
|
||||
"Starcoder2PreTrainedModel",
|
||||
]
|
||||
|
@ -6489,6 +6503,7 @@ if TYPE_CHECKING:
|
|||
from .models.gemma import (
|
||||
GemmaForCausalLM,
|
||||
GemmaForSequenceClassification,
|
||||
GemmaForTokenClassification,
|
||||
GemmaModel,
|
||||
GemmaPreTrainedModel,
|
||||
)
|
||||
|
@ -6686,6 +6701,7 @@ if TYPE_CHECKING:
|
|||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
|
@ -6801,12 +6817,14 @@ if TYPE_CHECKING:
|
|||
from .models.mistral import (
|
||||
MistralForCausalLM,
|
||||
MistralForSequenceClassification,
|
||||
MistralForTokenClassification,
|
||||
MistralModel,
|
||||
MistralPreTrainedModel,
|
||||
)
|
||||
from .models.mixtral import (
|
||||
MixtralForCausalLM,
|
||||
MixtralForSequenceClassification,
|
||||
MixtralForTokenClassification,
|
||||
MixtralModel,
|
||||
MixtralPreTrainedModel,
|
||||
)
|
||||
|
@ -7025,6 +7043,7 @@ if TYPE_CHECKING:
|
|||
from .models.persimmon import (
|
||||
PersimmonForCausalLM,
|
||||
PersimmonForSequenceClassification,
|
||||
PersimmonForTokenClassification,
|
||||
PersimmonModel,
|
||||
PersimmonPreTrainedModel,
|
||||
)
|
||||
|
@ -7099,12 +7118,14 @@ if TYPE_CHECKING:
|
|||
from .models.qwen2 import (
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2ForTokenClassification,
|
||||
Qwen2Model,
|
||||
Qwen2PreTrainedModel,
|
||||
)
|
||||
from .models.qwen2_moe import (
|
||||
Qwen2MoeForCausalLM,
|
||||
Qwen2MoeForSequenceClassification,
|
||||
Qwen2MoeForTokenClassification,
|
||||
Qwen2MoeModel,
|
||||
Qwen2MoePreTrainedModel,
|
||||
)
|
||||
|
@ -7306,12 +7327,14 @@ if TYPE_CHECKING:
|
|||
from .models.stablelm import (
|
||||
StableLmForCausalLM,
|
||||
StableLmForSequenceClassification,
|
||||
StableLmForTokenClassification,
|
||||
StableLmModel,
|
||||
StableLmPreTrainedModel,
|
||||
)
|
||||
from .models.starcoder2 import (
|
||||
Starcoder2ForCausalLM,
|
||||
Starcoder2ForSequenceClassification,
|
||||
Starcoder2ForTokenClassification,
|
||||
Starcoder2Model,
|
||||
Starcoder2PreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -1038,6 +1038,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||
("flaubert", "FlaubertForTokenClassification"),
|
||||
("fnet", "FNetForTokenClassification"),
|
||||
("funnel", "FunnelForTokenClassification"),
|
||||
("gemma", "GemmaForTokenClassification"),
|
||||
("gpt-sw3", "GPT2ForTokenClassification"),
|
||||
("gpt2", "GPT2ForTokenClassification"),
|
||||
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
|
||||
|
@ -1048,11 +1049,14 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
|
||||
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
|
||||
("lilt", "LiltForTokenClassification"),
|
||||
("llama", "LlamaForTokenClassification"),
|
||||
("longformer", "LongformerForTokenClassification"),
|
||||
("luke", "LukeForTokenClassification"),
|
||||
("markuplm", "MarkupLMForTokenClassification"),
|
||||
("mega", "MegaForTokenClassification"),
|
||||
("megatron-bert", "MegatronBertForTokenClassification"),
|
||||
("mistral", "MistralForTokenClassification"),
|
||||
("mixtral", "MixtralForTokenClassification"),
|
||||
("mobilebert", "MobileBertForTokenClassification"),
|
||||
("mpnet", "MPNetForTokenClassification"),
|
||||
("mpt", "MptForTokenClassification"),
|
||||
|
@ -1060,15 +1064,20 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||
("mt5", "MT5ForTokenClassification"),
|
||||
("nezha", "NezhaForTokenClassification"),
|
||||
("nystromformer", "NystromformerForTokenClassification"),
|
||||
("persimmon", "PersimmonForTokenClassification"),
|
||||
("phi", "PhiForTokenClassification"),
|
||||
("phi3", "Phi3ForTokenClassification"),
|
||||
("qdqbert", "QDQBertForTokenClassification"),
|
||||
("qwen2", "Qwen2ForTokenClassification"),
|
||||
("qwen2_moe", "Qwen2MoeForTokenClassification"),
|
||||
("rembert", "RemBertForTokenClassification"),
|
||||
("roberta", "RobertaForTokenClassification"),
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
|
||||
("roc_bert", "RoCBertForTokenClassification"),
|
||||
("roformer", "RoFormerForTokenClassification"),
|
||||
("squeezebert", "SqueezeBertForTokenClassification"),
|
||||
("stablelm", "StableLmForTokenClassification"),
|
||||
("starcoder2", "Starcoder2ForTokenClassification"),
|
||||
("t5", "T5ForTokenClassification"),
|
||||
("umt5", "UMT5ForTokenClassification"),
|
||||
("xlm", "XLMForTokenClassification"),
|
||||
|
|
|
@ -55,6 +55,7 @@ else:
|
|||
"GemmaModel",
|
||||
"GemmaPreTrainedModel",
|
||||
"GemmaForSequenceClassification",
|
||||
"GemmaForTokenClassification",
|
||||
]
|
||||
|
||||
try:
|
||||
|
@ -98,6 +99,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_gemma import (
|
||||
GemmaForCausalLM,
|
||||
GemmaForSequenceClassification,
|
||||
GemmaForTokenClassification,
|
||||
GemmaModel,
|
||||
GemmaPreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -30,7 +30,12 @@ from ...modeling_attn_mask_utils import (
|
|||
AttentionMaskConverter,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
|
||||
from ...utils import (
|
||||
|
@ -1346,3 +1351,88 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
GEMMA_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
|
||||
class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = GemmaModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -55,6 +55,7 @@ else:
|
|||
"LlamaPreTrainedModel",
|
||||
"LlamaForSequenceClassification",
|
||||
"LlamaForQuestionAnswering",
|
||||
"LlamaForTokenClassification",
|
||||
]
|
||||
|
||||
try:
|
||||
|
@ -95,6 +96,7 @@ if TYPE_CHECKING:
|
|||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -36,6 +36,7 @@ from ...modeling_outputs import (
|
|||
CausalLMOutputWithPast,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
|
@ -1516,3 +1517,87 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
|
|||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
LLAMA_START_DOCSTRING,
|
||||
)
|
||||
class LlamaForTokenClassification(LlamaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = LlamaModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -32,6 +32,7 @@ else:
|
|||
"MistralModel",
|
||||
"MistralPreTrainedModel",
|
||||
"MistralForSequenceClassification",
|
||||
"MistralForTokenClassification",
|
||||
]
|
||||
|
||||
try:
|
||||
|
@ -59,6 +60,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_mistral import (
|
||||
MistralForCausalLM,
|
||||
MistralForSequenceClassification,
|
||||
MistralForTokenClassification,
|
||||
MistralModel,
|
||||
MistralPreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
|
@ -1366,3 +1371,88 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
MISTRAL_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL
|
||||
class MistralForTokenClassification(MistralPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = MistralModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -36,6 +36,7 @@ else:
|
|||
"MixtralModel",
|
||||
"MixtralPreTrainedModel",
|
||||
"MixtralForSequenceClassification",
|
||||
"MixtralForTokenClassification",
|
||||
]
|
||||
|
||||
|
||||
|
@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_mixtral import (
|
||||
MixtralForCausalLM,
|
||||
MixtralForSequenceClassification,
|
||||
MixtralForTokenClassification,
|
||||
MixtralModel,
|
||||
MixtralPreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -38,6 +38,7 @@ from ...modeling_outputs import (
|
|||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||
|
@ -1582,3 +1583,88 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
MIXTRAL_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
|
||||
class MixtralForTokenClassification(MixtralPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = MixtralModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -36,6 +36,7 @@ else:
|
|||
"PersimmonModel",
|
||||
"PersimmonPreTrainedModel",
|
||||
"PersimmonForSequenceClassification",
|
||||
"PersimmonForTokenClassification",
|
||||
]
|
||||
|
||||
|
||||
|
@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_persimmon import (
|
||||
PersimmonForCausalLM,
|
||||
PersimmonForSequenceClassification,
|
||||
PersimmonForTokenClassification,
|
||||
PersimmonModel,
|
||||
PersimmonPreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -29,7 +29,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from .configuration_persimmon import PersimmonConfig
|
||||
|
@ -1011,3 +1016,88 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
PERSIMMON_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON
|
||||
class PersimmonForTokenClassification(PersimmonPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = PersimmonModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -45,6 +45,7 @@ else:
|
|||
"Qwen2Model",
|
||||
"Qwen2PreTrainedModel",
|
||||
"Qwen2ForSequenceClassification",
|
||||
"Qwen2ForTokenClassification",
|
||||
]
|
||||
|
||||
|
||||
|
@ -69,6 +70,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_qwen2 import (
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2ForTokenClassification,
|
||||
Qwen2Model,
|
||||
Qwen2PreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
|
@ -1375,3 +1380,88 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
QWEN2_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
|
||||
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = Qwen2Model(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -36,6 +36,7 @@ else:
|
|||
"Qwen2MoeModel",
|
||||
"Qwen2MoePreTrainedModel",
|
||||
"Qwen2MoeForSequenceClassification",
|
||||
"Qwen2MoeForTokenClassification",
|
||||
]
|
||||
|
||||
|
||||
|
@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_qwen2_moe import (
|
||||
Qwen2MoeForCausalLM,
|
||||
Qwen2MoeForSequenceClassification,
|
||||
Qwen2MoeForTokenClassification,
|
||||
Qwen2MoeModel,
|
||||
Qwen2MoePreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -32,7 +32,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_outputs import (
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
|
@ -1571,3 +1576,88 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
QWEN2MOE_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
|
||||
class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = Qwen2MoeModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -36,6 +36,7 @@ else:
|
|||
"StableLmModel",
|
||||
"StableLmPreTrainedModel",
|
||||
"StableLmForSequenceClassification",
|
||||
"StableLmForTokenClassification",
|
||||
]
|
||||
|
||||
|
||||
|
@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_stablelm import (
|
||||
StableLmForCausalLM,
|
||||
StableLmForSequenceClassification,
|
||||
StableLmForTokenClassification,
|
||||
StableLmModel,
|
||||
StableLmPreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -30,7 +30,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
|
@ -1383,3 +1388,88 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The StableLm Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
STABLELM_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM
|
||||
class StableLmForTokenClassification(StableLmPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = StableLmModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -36,6 +36,7 @@ else:
|
|||
"Starcoder2Model",
|
||||
"Starcoder2PreTrainedModel",
|
||||
"Starcoder2ForSequenceClassification",
|
||||
"Starcoder2ForTokenClassification",
|
||||
]
|
||||
|
||||
|
||||
|
@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||
from .modeling_starcoder2 import (
|
||||
Starcoder2ForCausalLM,
|
||||
Starcoder2ForSequenceClassification,
|
||||
Starcoder2ForTokenClassification,
|
||||
Starcoder2Model,
|
||||
Starcoder2PreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
|
@ -1357,3 +1362,88 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
||||
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
STARCODER2_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2
|
||||
class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = Starcoder2Model(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -3692,6 +3692,13 @@ class GemmaForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GemmaForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GemmaModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -4642,6 +4649,13 @@ class LlamaForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LlamaForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LlamaModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -5237,6 +5251,13 @@ class MistralForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MistralForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MistralModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -5265,6 +5286,13 @@ class MixtralForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MixtralForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MixtralModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -6373,6 +6401,13 @@ class PersimmonForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PersimmonForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PersimmonModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -6734,6 +6769,13 @@ class Qwen2ForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen2ForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen2Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -6762,6 +6804,13 @@ class Qwen2MoeForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen2MoeForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Qwen2MoeModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -7793,6 +7842,13 @@ class StableLmForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class StableLmForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class StableLmModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -7821,6 +7877,13 @@ class Starcoder2ForSequenceClassification(metaclass=DummyObject):
|
|||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Starcoder2ForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Starcoder2Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
|
|
@ -41,7 +41,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
|||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
|
||||
from transformers import (
|
||||
GemmaForCausalLM,
|
||||
GemmaForSequenceClassification,
|
||||
GemmaForTokenClassification,
|
||||
GemmaModel,
|
||||
GemmaTokenizer,
|
||||
)
|
||||
|
||||
|
||||
class GemmaModelTester:
|
||||
|
@ -284,12 +290,17 @@ class GemmaModelTester:
|
|||
|
||||
@require_torch
|
||||
class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": GemmaModel,
|
||||
"text-classification": GemmaForSequenceClassification,
|
||||
"token-classification": GemmaForTokenClassification,
|
||||
"text-generation": GemmaForCausalLM,
|
||||
"zero-shot": GemmaForSequenceClassification,
|
||||
}
|
||||
|
@ -370,6 +381,22 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
|
||||
def test_Gemma_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = GemmaForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
|
|
@ -47,6 +47,7 @@ if is_torch_available():
|
|||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
|
@ -286,7 +287,13 @@ class LlamaModelTester:
|
|||
@require_torch
|
||||
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
|
||||
(
|
||||
LlamaModel,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
@ -298,6 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
"text-generation": LlamaForCausalLM,
|
||||
"zero-shot": LlamaForSequenceClassification,
|
||||
"question-answering": LlamaForQuestionAnswering,
|
||||
"token-classification": LlamaForTokenClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
|
@ -370,6 +378,21 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_llama_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = LlamaForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Llama buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
|
|
@ -46,6 +46,7 @@ if is_torch_available():
|
|||
from transformers import (
|
||||
MistralForCausalLM,
|
||||
MistralForSequenceClassification,
|
||||
MistralForTokenClassification,
|
||||
MistralModel,
|
||||
)
|
||||
|
||||
|
@ -288,13 +289,16 @@ class MistralModelTester:
|
|||
@require_torch
|
||||
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(MistralModel, MistralForCausalLM, MistralForSequenceClassification) if is_torch_available() else ()
|
||||
(MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MistralModel,
|
||||
"text-classification": MistralForSequenceClassification,
|
||||
"token-classification": MistralForTokenClassification,
|
||||
"text-generation": MistralForCausalLM,
|
||||
"zero-shot": MistralForSequenceClassification,
|
||||
}
|
||||
|
@ -376,6 +380,22 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral
|
||||
def test_Mistral_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = MistralForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Mistral buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
|
|
@ -40,7 +40,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
|||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
|
||||
from transformers import (
|
||||
MixtralForCausalLM,
|
||||
MixtralForSequenceClassification,
|
||||
MixtralForTokenClassification,
|
||||
MixtralModel,
|
||||
)
|
||||
|
||||
|
||||
class MixtralModelTester:
|
||||
|
@ -287,13 +292,16 @@ class MixtralModelTester:
|
|||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
|
||||
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else ()
|
||||
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MixtralModel,
|
||||
"text-classification": MixtralForSequenceClassification,
|
||||
"token-classification": MixtralForTokenClassification,
|
||||
"text-generation": MixtralForCausalLM,
|
||||
"zero-shot": MixtralForSequenceClassification,
|
||||
}
|
||||
|
@ -375,6 +383,22 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral
|
||||
def test_Mixtral_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = MixtralForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
|
|
@ -44,6 +44,7 @@ if is_torch_available():
|
|||
AutoTokenizer,
|
||||
PersimmonForCausalLM,
|
||||
PersimmonForSequenceClassification,
|
||||
PersimmonForTokenClassification,
|
||||
PersimmonModel,
|
||||
)
|
||||
from transformers.models.persimmon.modeling_persimmon import (
|
||||
|
@ -283,12 +284,15 @@ class PersimmonModelTester:
|
|||
@require_torch
|
||||
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification) if is_torch_available() else ()
|
||||
(PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": PersimmonModel,
|
||||
"text-classification": PersimmonForSequenceClassification,
|
||||
"token-classification": PersimmonForTokenClassification,
|
||||
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
||||
# "text-generation": PersimmonForCausalLM,
|
||||
# "zero-shot": PersimmonForSequenceClassification,
|
||||
|
@ -365,6 +369,22 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon
|
||||
def test_persimmon_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = PersimmonForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
|
|
|
@ -45,6 +45,7 @@ if is_torch_available():
|
|||
from transformers import (
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2ForTokenClassification,
|
||||
Qwen2Model,
|
||||
)
|
||||
|
||||
|
@ -299,12 +300,17 @@ class Qwen2ModelTester:
|
|||
@require_torch
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
|
||||
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Qwen2Model,
|
||||
"text-classification": Qwen2ForSequenceClassification,
|
||||
"token-classification": Qwen2ForTokenClassification,
|
||||
"text-generation": Qwen2ForCausalLM,
|
||||
"zero-shot": Qwen2ForSequenceClassification,
|
||||
}
|
||||
|
@ -387,6 +393,22 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2
|
||||
def test_Qwen2_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Qwen2ForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Qwen2 buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
|
|
@ -45,6 +45,7 @@ if is_torch_available():
|
|||
from transformers import (
|
||||
Qwen2MoeForCausalLM,
|
||||
Qwen2MoeForSequenceClassification,
|
||||
Qwen2MoeForTokenClassification,
|
||||
Qwen2MoeModel,
|
||||
)
|
||||
|
||||
|
@ -327,13 +328,16 @@ class Qwen2MoeModelTester:
|
|||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
||||
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else ()
|
||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Qwen2MoeModel,
|
||||
"text-classification": Qwen2MoeForSequenceClassification,
|
||||
"token-classification": Qwen2MoeForTokenClassification,
|
||||
"text-generation": Qwen2MoeForCausalLM,
|
||||
"zero-shot": Qwen2MoeForSequenceClassification,
|
||||
}
|
||||
|
@ -414,6 +418,22 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe
|
||||
def test_Qwen2Moe_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Qwen2MoeForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
|
|
@ -43,6 +43,7 @@ if is_torch_available():
|
|||
AutoTokenizer,
|
||||
StableLmForCausalLM,
|
||||
StableLmForSequenceClassification,
|
||||
StableLmForTokenClassification,
|
||||
StableLmModel,
|
||||
)
|
||||
from transformers.models.stablelm.modeling_stablelm import (
|
||||
|
@ -287,12 +288,15 @@ class StableLmModelTester:
|
|||
# Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
|
||||
class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification) if is_torch_available() else ()
|
||||
(StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": StableLmModel,
|
||||
"text-classification": StableLmForSequenceClassification,
|
||||
"token-classification": StableLmForTokenClassification,
|
||||
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
||||
# "text-generation": StableLmForCausalLM,
|
||||
# "zero-shot": StableLmForSequenceClassification,
|
||||
|
@ -356,6 +360,22 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm
|
||||
def test_stablelm_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = StableLmForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
|
||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||
|
|
|
@ -43,6 +43,7 @@ if is_torch_available():
|
|||
AutoTokenizer,
|
||||
Starcoder2ForCausalLM,
|
||||
Starcoder2ForSequenceClassification,
|
||||
Starcoder2ForTokenClassification,
|
||||
Starcoder2Model,
|
||||
)
|
||||
|
||||
|
@ -290,13 +291,16 @@ class Starcoder2ModelTester:
|
|||
# Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
|
||||
class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification) if is_torch_available() else ()
|
||||
(Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Starcoder2Model,
|
||||
"text-classification": Starcoder2ForSequenceClassification,
|
||||
"token-classification": Starcoder2ForTokenClassification,
|
||||
"text-generation": Starcoder2ForCausalLM,
|
||||
"zero-shot": Starcoder2ForSequenceClassification,
|
||||
}
|
||||
|
@ -370,6 +374,22 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2
|
||||
def test_Starcoder2_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Starcoder2ForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue