This commit is contained in:
Arthur Zucker 2024-05-28 16:31:09 +02:00
parent 0faa82da98
commit 1836a758f8
5 changed files with 32 additions and 18 deletions

View File

@ -1,5 +1,4 @@
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒ # ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██ # ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██ # ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
@ -48,7 +47,7 @@
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░ # ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# This file was automatically generated from <path_to_diff_file.py>. # This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of # Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. # the file from the diff.
@ -72,7 +71,6 @@
# limitations under the License. # limitations under the License.
from transformers import PretrainedConfig from transformers import PretrainedConfig

View File

@ -23,15 +23,16 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaForCausalLM, LlamaForCausalLM,
LlamaForSequenceClassification, LlamaForSequenceClassification,
LlamaModel,
LlamaForTokenClassification, LlamaForTokenClassification,
LlamaModel,
apply_rotary_pos_emb, apply_rotary_pos_emb,
repeat_kv, repeat_kv,
) )
from transformers.models.llama.configuration_llama import LlamaConfig
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache from ...cache_utils import Cache
from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_outputs import CausalLMOutputWithPast
@ -109,6 +110,7 @@ class GemmaConfig(PretrainedConfig):
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
```""" ```"""
model_type = "gemma" model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
@ -161,6 +163,7 @@ class GemmaConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
# Example where we only want to overwrite the defaults of an init? # Example where we only want to overwrite the defaults of an init?
class GemmaConfig(LlamaConfig): class GemmaConfig(LlamaConfig):
def __init__( def __init__(
@ -188,6 +191,7 @@ class GemmaConfig(LlamaConfig):
): ):
super().__init__(self) super().__init__(self)
class GemmaRMSNorm(nn.Module): class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() super().__init__()
@ -407,6 +411,7 @@ class GemmaModel(LlamaModel):
cache_position, cache_position,
) )
# Example where we ony modify the docstring and call super # Example where we ony modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM): class GemmaForCausalLM(LlamaForCausalLM):
def forward( def forward(
@ -459,12 +464,13 @@ class GemmaForCausalLM(LlamaForCausalLM):
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
cache_position, cache_position,
) )
class GemmaForSequenceClassification(LlamaForSequenceClassification): class GemmaForSequenceClassification(LlamaForSequenceClassification):
pass pass
class GemmaForTokenClassification(LlamaForTokenClassification): class GemmaForTokenClassification(LlamaForTokenClassification):
pass pass

View File

@ -1,5 +1,4 @@
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒ # ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██ # ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██ # ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
@ -48,7 +47,7 @@
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░ # ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# This file was automatically generated from <path_to_diff_file.py>. # This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of # Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. # the file from the diff.
@ -137,6 +136,7 @@ def _get_unpad_data(attention_mask):
max_seqlen_in_batch, max_seqlen_in_batch,
) )
class GemmaRMSNorm(nn.Module): class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() super().__init__()

View File

@ -1,5 +1,4 @@
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒ # ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██ # ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██ # ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
@ -48,7 +47,7 @@
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██ # ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░ # ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
# This file was automatically generated from <path_to_diff_file.py>. # This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of # Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. # the file from the diff.

View File

@ -74,6 +74,7 @@ AUTO_GENERATED_MESSAGE = """
""" """
def get_module_source_from_name(module_name: str) -> str: def get_module_source_from_name(module_name: str) -> str:
spec = importlib.util.find_spec(module_name) spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None: if spec is None or spec.origin is None:
@ -236,7 +237,17 @@ def find_classes_in_file(module, old_id="llama", new_id="gemma"):
wrapper.visit(class_finder) wrapper.visit(class_finder)
return class_finder return class_finder
DOCSTRING_NODE = m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString(value=m.MatchIfTrue(lambda value: re.search(r'\"\"\"[\s\S]*\"\"\"',value) is not None)))])
DOCSTRING_NODE = m.SimpleStatementLine(
body=[
m.Expr(
value=m.SimpleString(
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
)
)
]
)
class SuperTransformer(cst.CSTTransformer): class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,) METADATA_DEPENDENCIES = (ParentNodeProvider,)
@ -347,15 +358,15 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
# TODO here is where we merge stuff from super. We can choose to merge the docstring as well! # TODO here is where we merge stuff from super. We can choose to merge the docstring as well!
# We could also check the docstring here # We could also check the docstring here
original_methods = {f.name.value if hasattr(f,"name") else f: f for f in original_node.body.body } original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
# Copy methods from original node to replacement node, preserving decorators # Copy methods from original node to replacement node, preserving decorators
updated_methods = {f.name.value if hasattr(f,"name") else f: f for f in updated_node.body.body } updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
end_meth = [] end_meth = []
for name, func in original_methods.items(): for name, func in original_methods.items():
if name in updated_methods: if name in updated_methods:
# Replace the method in the replacement class, preserving decorators # Replace the method in the replacement class, preserving decorators
func = func.with_changes(body=updated_methods[name].body, params = updated_methods[name].params ) func = func.with_changes(body=updated_methods[name].body, params=updated_methods[name].params)
end_meth.append(func) end_meth.append(func)
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))