Arthur Zucker 2024-05-18 10:35:27 +02:00
parent 65a00cefba
commit 292e573321
3 changed files with 7 additions and 15 deletions

View File

@@ -247,7 +247,6 @@ class GemmaModel(LlamaModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -264,7 +263,7 @@ class GemmaModel(LlamaModel):
hidden_states = inputs_embeds
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
return super().forward(
None,
attention_mask,

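A schematic sketch of the subclass-and-override pattern the hunk above uses: Gemma is written as a LlamaModel subclass whose forward only adds the sqrt(hidden_size) embedding scaling before delegating to the parent. The toy classes, names, and sizes below are stand-ins for illustration, not the transformers implementations.

import torch
import torch.nn as nn

class ToyLlamaModel(nn.Module):
    """Stand-in for LlamaModel: embed tokens, then run the decoder stack (omitted here)."""
    def __init__(self, vocab_size: int = 256, hidden_size: int = 2048):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids=None, inputs_embeds=None):
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return inputs_embeds  # decoder layers omitted in this sketch

class ToyGemmaModel(ToyLlamaModel):
    """Only the Gemma-specific step is written out: scale the embeddings by
    sqrt(hidden_size), then delegate to the parent forward with inputs_embeds set."""
    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Building the scalar in the embeddings' dtype avoids an implicit upcast
        # (e.g. bfloat16 activations stay bfloat16), mirroring the diffed code.
        normalizer = torch.tensor(self.hidden_size**0.5, dtype=inputs_embeds.dtype)
        return super().forward(None, inputs_embeds * normalizer)

out = ToyGemmaModel()(input_ids=torch.tensor([[1, 2, 3]]))
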
View File

@@ -21,6 +21,8 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -30,7 +32,6 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
@@ -38,14 +39,11 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_gemma import GemmaConfig
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
@@ -207,6 +205,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
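The repeat_kv context shown above is the grouped-query-attention helper: key/value tensors carry fewer heads than the queries, so they are tiled n_rep times before the attention product. A quick sketch of the same expand + reshape pattern on toy shapes (all sizes below are made up for illustration):

import torch

batch, kv_heads, q_heads, seq_len, head_dim = 1, 2, 8, 5, 16
n_rep = q_heads // kv_heads  # 4 query heads share each key/value head

keys = torch.randn(batch, kv_heads, seq_len, head_dim)
# Same trick as repeat_kv: insert a repeat axis, broadcast it with expand (no copy),
# then fold it into the head dimension with reshape (which copies if needed).
tiled = keys[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
tiled = tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
assert tiled.shape == (batch, q_heads, seq_len, head_dim)
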
""" PyTorch Gemma model."""
import math
@@ -217,13 +217,6 @@ import torch.utils.checkpoint
from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
LlamaPreTrainedModel,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -698,6 +691,7 @@ class GemmaSdpaAttention(GemmaAttention):
return attn_output, None, past_key_value
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
"flash_attention_2": GemmaFlashAttention2,
@@ -853,7 +847,6 @@ class GemmaModel(GemmaPreTrainedModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"

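The GEMMA_ATTENTION_CLASSES mapping in the hunk above is a plain string-to-class dispatch keyed on the configured attention backend. A toy sketch of that pattern follows; the placeholder classes and the "sdpa" entry (suggested by GemmaSdpaAttention appearing in the same hunk) are assumptions, not the real implementations.

class EagerAttention:
    """Reference attention written in plain PyTorch ops."""

class FlashAttention2:
    """Backend backed by the flash-attn kernels."""

class SdpaAttention:
    """Backend using torch.nn.functional.scaled_dot_product_attention."""

ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
    "sdpa": SdpaAttention,
}

def pick_attention(attn_implementation: str):
    # Unknown backends fall back to the eager reference implementation.
    return ATTENTION_CLASSES.get(attn_implementation, EagerAttention)
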
View File

@@ -238,7 +238,7 @@ class SuperTransformer(cst.CSTTransformer):
existing_body.append(stmt)
existing_nodes.add(stmt)
return existing_body
if m.matches(
updated_node.value,
m.Call(func=m.Attribute(value=m.Call(func=m.Name(value="super")), attr=m.Name("__init__"))),
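The m.matches call above uses libcst matchers to recognize super().__init__(...) calls while the transformer rewrites a subclass body. A self-contained sketch of that matcher on a toy module (the class names in the parsed source are made up):

import libcst as cst
import libcst.matchers as m

# Same shape as the matcher in the hunk above: a call whose func is the attribute
# `__init__` looked up on the result of calling `super()`.
SUPER_INIT_CALL = m.Call(
    func=m.Attribute(value=m.Call(func=m.Name(value="super")), attr=m.Name("__init__"))
)

module = cst.parse_module(
    "class Child(Base):\n"
    "    def __init__(self):\n"
    "        super().__init__()\n"
)
print(len(m.findall(module, SUPER_INIT_CALL)))  # 1
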