From f0068b702c189d02a0a5142171aeee3f84b65929 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Tue, 28 May 2024 10:55:29 +0200
Subject: [PATCH] update

---
 src/transformers/models/gemma/configuration_gemma.py |  1 +
 src/transformers/models/gemma/diff_gemma.py          |  4 ++++
 src/transformers/models/gemma/modeling_gemma.py      | 10 ++++++----
 3 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py
index 234ed60f22..4f5c772600 100644
--- a/src/transformers/models/gemma/configuration_gemma.py
+++ b/src/transformers/models/gemma/configuration_gemma.py
@@ -15,6 +15,7 @@
 # limitations under the License.
+
 from transformers import PretrainedConfig

diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py
index 30eeb92267..4cc1c5a87f 100644
--- a/src/transformers/models/gemma/diff_gemma.py
+++ b/src/transformers/models/gemma/diff_gemma.py
@@ -28,6 +28,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
+    LlamaForTokenClassification,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -406,3 +407,6 @@ class GemmaForCausalLM(LlamaForCausalLM):
 
 class GemmaForSequenceClassification(LlamaForSequenceClassification):
     pass
+
+class GemmaForTokenClassification(LlamaForTokenClassification):
+    pass
\ No newline at end of file
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index fe3aed1fbd..359bc4788d 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -30,6 +30,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -275,6 +276,7 @@ class GemmaAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -286,8 +288,8 @@
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -349,7 +351,6 @@ class GemmaFlashAttention2(GemmaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -812,6 +813,7 @@ _CONFIG_FOR_DOC = "GemmaConfig"
     GEMMA_START_DOCSTRING,
 )
 class GemmaModel(GemmaPreTrainedModel):
+
     def __init__(self, config: GemmaConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
@@ -1031,6 +1033,7 @@ class GemmaModel(GemmaPreTrainedModel):
 
 
 class GemmaForCausalLM(GemmaPreTrainedModel):
+
     def __init__(self, config):
         super().__init__(config)
         self.model = GemmaModel(config)
@@ -1365,7 +1368,6 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
     """,
     GEMMA_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
 class GemmaForTokenClassification(GemmaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
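
Usage note (not part of the diff): a minimal sketch of the GemmaForTokenClassification head this patch introduces, assuming a transformers build that already includes the change and exports the class at the top level. The checkpoint name "google/gemma-2b" and num_labels=5 are placeholder choices for illustration, not values fixed by the patch.

    # Minimal sketch, assuming transformers exports GemmaForTokenClassification after this patch.
    # "google/gemma-2b" and num_labels=5 are placeholders, not mandated by the diff above.
    import torch
    from transformers import AutoTokenizer, GemmaForTokenClassification

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    model = GemmaForTokenClassification.from_pretrained("google/gemma-2b", num_labels=5)

    inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
    with torch.no_grad():
        # TokenClassifierOutput.logits has shape (batch_size, sequence_length, num_labels)
        logits = model(**inputs).logits

    # One predicted label id per input token.
    predictions = logits.argmax(dim=-1)

As with the Llama parent class it mirrors in diff_gemma.py, the head scores every token position, so the output keeps the full sequence length rather than pooling to a single label.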