deal with duplicates

parent e3be54cf25
commit 65a00cefba

@@ -247,11 +247,24 @@ class GemmaModel(LlamaModel):
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        return super().forward(
            None,
            attention_mask,

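Note: the normalizer line in this hunk is Gemma's embedding scaling by sqrt(hidden_size). A quick standalone check of what that scaling does (the hidden_size value below is an arbitrary example, not taken from the diff):

import torch

hidden_size = 2048  # arbitrary example value standing in for config.hidden_size
inputs_embeds = torch.randn(2, 5, hidden_size)

# Same pattern as in the hunk: scale the embeddings by sqrt(hidden_size), keeping the dtype.
normalizer = torch.tensor(hidden_size**0.5, dtype=inputs_embeds.dtype)
hidden_states = inputs_embeds * normalizer

print(float(hidden_states.std() / inputs_embeds.std()))  # roughly sqrt(2048), about 45.25
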
@@ -21,8 +21,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -32,6 +30,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel

@@ -39,11 +38,14 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_gemma import GemmaConfig
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


logger = logging.get_logger(__name__)

@@ -205,8 +207,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


""" PyTorch Gemma model."""

import math

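Note: a quick standalone check of the repeat_kv expansion shown above (the shapes are arbitrary example values):

import torch

batch, num_key_value_heads, n_rep, slen, head_dim = 1, 2, 3, 4, 8  # example sizes
hidden_states = torch.randn(batch, num_key_value_heads, slen, head_dim)

# Same expand/reshape as in the hunk: each KV head is repeated n_rep times along the head dimension.
out = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
out = out.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

assert out.shape == (1, 6, 4, 8)
assert torch.equal(out[:, 0], out[:, 1])  # the copies of KV head 0 sit next to each other
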
@@ -215,9 +215,15 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaSdpaAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)

@@ -692,7 +698,6 @@ class GemmaSdpaAttention(GemmaAttention):

        return attn_output, None, past_key_value


GEMMA_ATTENTION_CLASSES = {
    "eager": GemmaAttention,
    "flash_attention_2": GemmaFlashAttention2,

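Note: a mapping like GEMMA_ATTENTION_CLASSES is typically indexed by the attention implementation requested in the model config. A self-contained sketch of that lookup pattern (the class names and key below are placeholders, not part of this diff):

class EagerAttention: ...
class FlashAttention2: ...

ATTENTION_CLASSES = {"eager": EagerAttention, "flash_attention_2": FlashAttention2}

requested = "eager"  # in practice this would come from the model configuration
print(ATTENTION_CLASSES[requested].__name__)
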
@@ -848,8 +853,20 @@ class GemmaModel(GemmaPreTrainedModel):
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

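Note: the same forward body appears in the GemmaModel(LlamaModel) override (first hunk) and again here under GemmaModel(GemmaPreTrainedModel); the SuperTransformer in the last hunk is the piece that merges a child override with its parent's body. A toy, self-contained sketch of that merging idea with libcst (the parsed functions and the transformer below are illustrative only, not the converter's actual code):

import libcst as cst
import libcst.matchers as m

parent = cst.parse_module(
    "def forward(self, x):\n"
    "    return self.layers(x)\n"
)
child = cst.parse_module(
    "def forward(self, x):\n"
    "    x = x * 2\n"
    "    return super().forward(x)\n"
)

parent_body = parent.body[0].body.body  # statements of the parent forward


class InlineSuper(cst.CSTTransformer):
    def leave_IndentedBlock(self, original_node, updated_node):
        new_body = []
        for stmt in updated_node.body:
            if m.findall(stmt, m.Call(func=m.Attribute(value=m.Call(func=m.Name("super"))))):
                new_body.extend(parent_body)  # splice in the parent statements instead of the super() call
            else:
                new_body.append(stmt)
        return updated_node.with_changes(body=new_body)


print(child.visit(InlineSuper()).code)
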
@@ -223,12 +223,22 @@ class SuperTransformer(cst.CSTTransformer):
                    ]
                ),
            ):
                # TODO here we have the body, we can most probably remove the duplicates!
                new_body.extend(self.original_methods[func_name].body.body)
                new_body = self.update_body(new_body, self.original_methods[func_name].body.body)
            else:
                new_body.append(expr)
        return node.with_changes(body=new_body)

    def update_body(self, existing_body, new_statements):
        """
        Helper method to update the body by removing duplicates before adding new statements.
        """
        existing_nodes = {node for node in existing_body if isinstance(node, cst.CSTNode)}
        for stmt in new_statements:
            if isinstance(stmt, cst.CSTNode) and stmt not in existing_nodes:
                existing_body.append(stmt)
                existing_nodes.add(stmt)
        return existing_body

        if m.matches(
            updated_node.value,
            m.Call(func=m.Attribute(value=m.Call(func=m.Name(value="super")), attr=m.Name("__init__"))),

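Note: this hunk swaps the bare new_body.extend(...) call (flagged by the TODO about duplicates) for the new update_body helper, which only appends statements that are not already in the body. A minimal standalone sketch of the helper's behavior, assuming libcst is installed (the parsed statements are illustrative only):

import libcst as cst


def update_body(existing_body, new_statements):
    """Append new statements, skipping nodes that are already present in the body."""
    existing_nodes = {node for node in existing_body if isinstance(node, cst.CSTNode)}
    for stmt in new_statements:
        if isinstance(stmt, cst.CSTNode) and stmt not in existing_nodes:
            existing_body.append(stmt)
            existing_nodes.add(stmt)
    return existing_body


module = cst.parse_module("x = 1\ny = 2\n")
stmts = list(module.body)

merged = update_body(stmts[:1], stmts)  # the body already holds the node for "x = 1"
print(len(merged))  # 2: "x = 1" is recognized as a duplicate, only "y = 2" is appended
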