fixups
This commit is contained in:
parent
0faa82da98
commit
1836a758f8
|
@ -1,5 +1,4 @@
|
|||
|
||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
||||
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
||||
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
||||
|
@ -48,7 +47,7 @@
|
|||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff.
|
||||
|
@ -72,7 +71,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
|
|
|
@ -23,15 +23,16 @@ import torch.utils.checkpoint
|
|||
from torch import nn
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaModel,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import CausalLMOutputWithPast
|
||||
|
@ -109,6 +110,7 @@ class GemmaConfig(PretrainedConfig):
|
|||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
|
@ -161,6 +163,7 @@ class GemmaConfig(PretrainedConfig):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Example where we only want to overwrite the defaults of an init?
|
||||
class GemmaConfig(LlamaConfig):
|
||||
def __init__(
|
||||
|
@ -188,6 +191,7 @@ class GemmaConfig(LlamaConfig):
|
|||
):
|
||||
super().__init__(self)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
|
@ -407,6 +411,7 @@ class GemmaModel(LlamaModel):
|
|||
cache_position,
|
||||
)
|
||||
|
||||
|
||||
# Example where we ony modify the docstring and call super
|
||||
class GemmaForCausalLM(LlamaForCausalLM):
|
||||
def forward(
|
||||
|
@ -459,12 +464,13 @@ class GemmaForCausalLM(LlamaForCausalLM):
|
|||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
cache_position,
|
||||
cache_position,
|
||||
)
|
||||
|
||||
|
||||
class GemmaForSequenceClassification(LlamaForSequenceClassification):
|
||||
pass
|
||||
|
||||
|
||||
class GemmaForTokenClassification(LlamaForTokenClassification):
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
|
||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
||||
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
||||
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
||||
|
@ -48,7 +47,7 @@
|
|||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff.
|
||||
|
@ -137,6 +136,7 @@ def _get_unpad_data(attention_mask):
|
|||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
|
||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
||||
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
||||
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
||||
|
@ -48,7 +47,7 @@
|
|||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff.
|
||||
|
|
|
@ -74,6 +74,7 @@ AUTO_GENERATED_MESSAGE = """
|
|||
|
||||
"""
|
||||
|
||||
|
||||
def get_module_source_from_name(module_name: str) -> str:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None or spec.origin is None:
|
||||
|
@ -236,7 +237,17 @@ def find_classes_in_file(module, old_id="llama", new_id="gemma"):
|
|||
wrapper.visit(class_finder)
|
||||
return class_finder
|
||||
|
||||
DOCSTRING_NODE = m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString(value=m.MatchIfTrue(lambda value: re.search(r'\"\"\"[\s\S]*\"\"\"',value) is not None)))])
|
||||
|
||||
DOCSTRING_NODE = m.SimpleStatementLine(
|
||||
body=[
|
||||
m.Expr(
|
||||
value=m.SimpleString(
|
||||
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SuperTransformer(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||
|
@ -347,15 +358,15 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
|||
|
||||
# TODO here is where we merge stuff from super. We can choose to merge the docstring as well!
|
||||
# We could also check the docstring here
|
||||
original_methods = {f.name.value if hasattr(f,"name") else f: f for f in original_node.body.body }
|
||||
original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
|
||||
|
||||
# Copy methods from original node to replacement node, preserving decorators
|
||||
updated_methods = {f.name.value if hasattr(f,"name") else f: f for f in updated_node.body.body }
|
||||
updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
|
||||
end_meth = []
|
||||
for name, func in original_methods.items():
|
||||
if name in updated_methods:
|
||||
# Replace the method in the replacement class, preserving decorators
|
||||
func = func.with_changes(body=updated_methods[name].body, params = updated_methods[name].params )
|
||||
func = func.with_changes(body=updated_methods[name].body, params=updated_methods[name].params)
|
||||
end_meth.append(func)
|
||||
|
||||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||
|
|
Loading…
Reference in New Issue