deal with duplicates

Arthur Zucker 2024-05-18 10:35:16 +02:00
parent e3be54cf25
commit 65a00cefba
3 changed files with 48 additions and 8 deletions

View File

@@ -247,11 +247,24 @@ class GemmaModel(LlamaModel):
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        return super().forward(
            None,
            attention_mask,
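
The override above mirrors LlamaModel.forward but adds Gemma's embedding normalizer: the embeddings are scaled by sqrt(hidden_size) before the rest of the computation is delegated to the parent class's forward. A minimal, self-contained sketch of just that scaling step, with a made-up hidden_size rather than a value read from a real Gemma config:

import torch

hidden_size = 2048  # hypothetical value for illustration only
inputs_embeds = torch.randn(1, 4, hidden_size, dtype=torch.bfloat16)

# As in the hunk above, the sqrt(hidden_size) scale is materialized as a tensor
# in the embeddings' dtype and multiplied into the hidden states.
normalizer = torch.tensor(hidden_size**0.5, dtype=inputs_embeds.dtype)
hidden_states = inputs_embeds * normalizer
print(hidden_states.shape, hidden_states.dtype)  # torch.Size([1, 4, 2048]) torch.bfloat16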

View File

@@ -21,8 +21,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -32,6 +30,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
@@ -39,11 +38,14 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_gemma import GemmaConfig
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
@@ -205,8 +207,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
""" PyTorch Gemma model."""
import math
@@ -215,9 +215,15 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaSdpaAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)
@@ -692,7 +698,6 @@ class GemmaSdpaAttention(GemmaAttention):
        return attn_output, None, past_key_value
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
"flash_attention_2": GemmaFlashAttention2,
@@ -848,8 +853,20 @@ class GemmaModel(GemmaPreTrainedModel):
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
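
The import hunks earlier in this file move the flash_attn imports and add is_flash_attn_2_available to the ...utils import. For reference, the pattern transformers normally uses for this optional dependency is to import flash_attn only when it is actually installed; a hedged sketch follows (the None fallbacks are placeholders, not necessarily what the generated file ends up containing):

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
else:
    # Placeholder fallbacks so the module still imports without flash-attn installed.
    flash_attn_func = flash_attn_varlen_func = None
    index_first_axis = pad_input = unpad_input = None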

View File

@@ -223,12 +223,22 @@ class SuperTransformer(cst.CSTTransformer):
                    ]
                ),
            ):
                # TODO here we have the body, we can most probably remove the duplicates!
                new_body.extend(self.original_methods[func_name].body.body)
                new_body = self.update_body(new_body, self.original_methods[func_name].body.body)
            else:
                new_body.append(expr)
        return node.with_changes(body=new_body)

    def update_body(self, existing_body, new_statements):
        """
        Helper method to update the body by removing duplicates before adding new statements.
        """
        existing_nodes = {node for node in existing_body if isinstance(node, cst.CSTNode)}
        for stmt in new_statements:
            if isinstance(stmt, cst.CSTNode) and stmt not in existing_nodes:
                existing_body.append(stmt)
                existing_nodes.add(stmt)
        return existing_body

        if m.matches(
            updated_node.value,
            m.Call(func=m.Attribute(value=m.Call(func=m.Name(value="super")), attr=m.Name("__init__"))),
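
The update_body helper added above implements the deduplication the commit title refers to: when the converter inlines the body of a parent method in place of a super().<method>() call, statements that are already present are skipped instead of being appended a second time. A standalone sketch of that merge-with-dedup idea, using a made-up Statement class in place of libcst nodes (the set membership test here relies on identity-based equality and hashing):

from dataclasses import dataclass
from typing import List


@dataclass(eq=False)  # eq=False keeps the default identity-based __eq__/__hash__
class Statement:
    code: str


def update_body(existing_body: List[Statement], new_statements: List[Statement]) -> List[Statement]:
    """Append only the statements that are not already present in existing_body."""
    seen = set(existing_body)
    for stmt in new_statements:
        if stmt not in seen:
            existing_body.append(stmt)
            seen.add(stmt)
    return existing_body


shared = Statement("hidden_states = inputs_embeds")
child_body = [shared, Statement("hidden_states = hidden_states * normalizer")]
parent_body = [shared, Statement("return hidden_states")]

merged = update_body(child_body, parent_body)
print([s.code for s in merged])
# ['hidden_states = inputs_embeds', 'hidden_states = hidden_states * normalizer', 'return hidden_states']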