This commit is contained in:
Arthur Zucker 2024-05-28 16:31:09 +02:00
parent 0faa82da98
commit 1836a758f8
5 changed files with 32 additions and 18 deletions

View File

@ -1,5 +1,4 @@
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
@ -48,7 +47,7 @@
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff.
@ -72,7 +71,6 @@
# limitations under the License.
from transformers import PretrainedConfig

View File

@ -23,15 +23,16 @@ import torch.utils.checkpoint
from torch import nn
from transformers import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
LlamaForTokenClassification,
LlamaModel,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.models.llama.configuration_llama import LlamaConfig
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_outputs import CausalLMOutputWithPast
@ -109,6 +110,7 @@ class GemmaConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
@ -161,6 +163,7 @@ class GemmaConfig(PretrainedConfig):
**kwargs,
)
# Example where we only want to overwrite the defaults of an init?
class GemmaConfig(LlamaConfig):
def __init__(
@ -188,6 +191,7 @@ class GemmaConfig(LlamaConfig):
):
super().__init__(self)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@ -407,6 +411,7 @@ class GemmaModel(LlamaModel):
cache_position,
)
# Example where we ony modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM):
def forward(
@ -459,12 +464,13 @@ class GemmaForCausalLM(LlamaForCausalLM):
output_attentions,
output_hidden_states,
return_dict,
cache_position,
cache_position,
)
class GemmaForSequenceClassification(LlamaForSequenceClassification):
pass
class GemmaForTokenClassification(LlamaForTokenClassification):
pass
pass

View File

@ -1,5 +1,4 @@
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
@ -48,7 +47,7 @@
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff.
@ -137,6 +136,7 @@ def _get_unpad_data(attention_mask):
max_seqlen_in_batch,
)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()

View File

@ -1,5 +1,4 @@
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
@ -48,7 +47,7 @@
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff.

View File

@ -74,6 +74,7 @@ AUTO_GENERATED_MESSAGE = """
"""
def get_module_source_from_name(module_name: str) -> str:
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
@ -236,7 +237,17 @@ def find_classes_in_file(module, old_id="llama", new_id="gemma"):
wrapper.visit(class_finder)
return class_finder
DOCSTRING_NODE = m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString(value=m.MatchIfTrue(lambda value: re.search(r'\"\"\"[\s\S]*\"\"\"',value) is not None)))])
DOCSTRING_NODE = m.SimpleStatementLine(
body=[
m.Expr(
value=m.SimpleString(
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
)
)
]
)
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
@ -347,15 +358,15 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
# TODO here is where we merge stuff from super. We can choose to merge the docstring as well!
# We could also check the docstring here
original_methods = {f.name.value if hasattr(f,"name") else f: f for f in original_node.body.body }
original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
# Copy methods from original node to replacement node, preserving decorators
updated_methods = {f.name.value if hasattr(f,"name") else f: f for f in updated_node.body.body }
updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
end_meth = []
for name, func in original_methods.items():
if name in updated_methods:
# Replace the method in the replacement class, preserving decorators
func = func.with_changes(body=updated_methods[name].body, params = updated_methods[name].params )
func = func.with_changes(body=updated_methods[name].body, params=updated_methods[name].params)
end_meth.append(func)
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))