Arthur Zucker 2024-05-28 10:55:29 +02:00
parent df19157424
commit f0068b702c
3 changed files with 11 additions and 4 deletions

View File

@@ -15,6 +15,7 @@
# limitations under the License.
from transformers import PretrainedConfig
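
This hunk is on the configuration side: PretrainedConfig is the base class every model configuration subclasses. A minimal, hypothetical sketch of that pattern (the class name, model_type, and defaults below are illustrative, not this file's contents):

from transformers import PretrainedConfig

class MyGemmaLikeConfig(PretrainedConfig):
    # Hypothetical config, only to illustrate the PretrainedConfig subclassing pattern.
    model_type = "my_gemma_like"

    def __init__(self, vocab_size=256000, hidden_size=2048, num_hidden_layers=18, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        super().__init__(**kwargs)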

View File

@@ -28,6 +28,7 @@ from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
LlamaForTokenClassification,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -406,3 +407,6 @@ class GemmaForCausalLM(LlamaForCausalLM):
class GemmaForSequenceClassification(LlamaForSequenceClassification):
pass
class GemmaForTokenClassification(LlamaForTokenClassification):
pass
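
With the class above in place, a token-classification Gemma can be used like the other task heads. A minimal usage sketch, assuming the class is exported from the top-level transformers namespace once this lands; the checkpoint name and num_labels are illustrative, not part of this commit:

from transformers import AutoTokenizer, GemmaForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")  # illustrative checkpoint
model = GemmaForTokenClassification.from_pretrained("google/gemma-2b", num_labels=5)

inputs = tokenizer("Token classification with Gemma", return_tensors="pt")
logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels)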

View File

@@ -30,6 +30,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
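
TokenClassifierOutput, imported above, is the standard output container that a token-level head returns. A small illustrative sketch of how it is populated (shapes are made up; the actual forward body is not part of this hunk):

import torch
from transformers.modeling_outputs import TokenClassifierOutput

logits = torch.randn(2, 16, 5)  # (batch, seq_len, num_labels)
out = TokenClassifierOutput(loss=None, logits=logits)
print(out.logits.shape)  # torch.Size([2, 16, 5])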
@@ -275,6 +276,7 @@ class GemmaAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -286,8 +288,8 @@ class GemmaAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
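
For context on the calls above: apply_rotary_pos_emb rotates the query and key states by the cos/sin tables produced by the rotary embedding before attention scores are computed. A sketch of the standard rotate-half formulation used in transformers (treating the extra trailing None in one variant as the deprecated position_ids slot is an assumption based on the function's signature of that era, not something this hunk shows):

import torch

def rotate_half(x):
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # position_ids is the deprecated, unused slot; unsqueeze so cos/sin broadcast
    # against (batch, num_heads, seq_len, head_dim).
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed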
@@ -349,7 +351,6 @@ class GemmaFlashAttention2(GemmaAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -812,6 +813,7 @@ _CONFIG_FOR_DOC = "GemmaConfig"
GEMMA_START_DOCSTRING,
)
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
@@ -1031,6 +1033,7 @@ class GemmaModel(GemmaPreTrainedModel):
class GemmaForCausalLM(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)
@@ -1365,7 +1368,6 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
""",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
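
The rest of this class is not visible in the hunk; per the "Copied from" comment it mirrors LlamaForTokenClassification. A simplified, self-contained sketch of the general shape of such a head (dropout plus a per-token linear score layer, with an optional cross-entropy loss), not the exact transformers code:

import torch
from torch import nn

class SimpleTokenClassificationHead(nn.Module):
    # Hypothetical stand-in for the head GemmaForTokenClassification puts on top
    # of the decoder's per-token hidden states.
    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.score = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states, labels=None):
        logits = self.score(self.dropout(hidden_states))  # (batch, seq_len, num_labels)
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits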