fixup
This commit is contained in:
parent
df9e78377b
commit
c804b4bc6d
|
@ -505,7 +505,7 @@ class GemmaSdpaAttention(GemmaAttention):
|
|||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
COHERE_ATTENTION_CLASSES = {
|
||||
GEMMA_ATTENTION_CLASSES = {
|
||||
"eager": GemmaAttention,
|
||||
"flash_attention_2": GemmaFlashAttention2,
|
||||
"sdpa": GemmaSdpaAttention,
|
||||
|
|
|
@ -505,7 +505,7 @@ class GemmaSdpaAttention(GemmaAttention):
|
|||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
COHERE_ATTENTION_CLASSES = {
|
||||
GEMMA_ATTENTION_CLASSES = {
|
||||
"eager": GemmaAttention,
|
||||
"flash_attention_2": GemmaFlashAttention2,
|
||||
"sdpa": GemmaSdpaAttention,
|
||||
|
|
|
@ -15,7 +15,6 @@ import argparse
|
|||
import gc
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
@ -304,7 +303,9 @@ def write_model(
|
|||
gc.collect()
|
||||
|
||||
print("Loading the checkpoint in a Llama model.")
|
||||
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) # Avoid saving this as part of the config. del model.config._name_or_path model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
|
||||
) # Avoid saving this as part of the config. del model.config._name_or_path model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path)
|
||||
|
||||
|
||||
class Llama3Converter(TikTokenConverter):
|
||||
|
|
|
@ -1,21 +1,26 @@
|
|||
import libcst as cst
|
||||
from libcst import matchers as m
|
||||
import importlib
|
||||
import re
|
||||
|
||||
import libcst as cst
|
||||
from libcst import matchers as m
|
||||
|
||||
|
||||
# Should we use the scope to figure out if classes are imported and inherited from
|
||||
# then go from here, instead of visiting the classes?
|
||||
|
||||
|
||||
def get_module_source_from_name(module_name: str) -> str:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None or spec.origin is None:
|
||||
return f"Module {module_name} not found"
|
||||
|
||||
with open(spec.origin, 'r') as file:
|
||||
|
||||
with open(spec.origin, "r") as file:
|
||||
source_code = file.read()
|
||||
|
||||
|
||||
return source_code
|
||||
|
||||
from libcst import parse_module, CSTVisitor, CSTTransformer, ClassDef, Name
|
||||
|
||||
from libcst import ClassDef, CSTTransformer, CSTVisitor, Name
|
||||
|
||||
|
||||
class ClassFinder(CSTVisitor):
|
||||
|
@ -27,7 +32,6 @@ class ClassFinder(CSTVisitor):
|
|||
self.classes[node.name.value] = node
|
||||
|
||||
|
||||
|
||||
class ReplaceNameTransformer(CSTTransformer):
|
||||
def __init__(self, old_name, new_name):
|
||||
self.new_name = new_name
|
||||
|
@ -43,15 +47,15 @@ class ReplaceNameTransformer(CSTTransformer):
|
|||
return self.new_name.title()
|
||||
else:
|
||||
return self.new_name.lower()
|
||||
|
||||
return self.regex.sub(replace, text)
|
||||
|
||||
def leave_Name(self, original_node: Name, updated_node: Name) -> Name:
|
||||
# Replace 'Llama' with 'Cohere' in names
|
||||
updated_value = self.preserve_case_replace(updated_node.value)
|
||||
return updated_node.with_changes(value=updated_value)
|
||||
|
||||
|
||||
def find_classes_in_file(module, old_id = "llama", new_id="gemma"):
|
||||
def find_classes_in_file(module, old_id="llama", new_id="gemma"):
|
||||
transformer = ReplaceNameTransformer(old_id, new_id)
|
||||
new_module = module.visit(transformer)
|
||||
|
||||
|
@ -59,9 +63,11 @@ def find_classes_in_file(module, old_id = "llama", new_id="gemma"):
|
|||
new_module.visit(class_finder)
|
||||
return class_finder.classes, new_module
|
||||
|
||||
# Define a visitor to traverse the tree and find import statements
|
||||
class ImportVisitor(cst.CSTVisitor):
|
||||
|
||||
class DiffConverterTransformer(CSTTransformer):
|
||||
def __init__(self, python_module):
|
||||
super().__init__()
|
||||
self.python_module = python_module
|
||||
self.transformers_imports = {}
|
||||
self.transformers_mapping = {}
|
||||
self.class_mapping = {}
|
||||
|
@ -69,13 +75,11 @@ class ImportVisitor(cst.CSTVisitor):
|
|||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
if m.matches(node.module, m.Attribute()):
|
||||
full_statement = self.python_module.code_for_node(node.module)
|
||||
full_statement = self.python_module.code_for_node(node.module)
|
||||
for imported_ in node.names:
|
||||
if "modeling_" in full_statement:
|
||||
print(f"resolving import from {full_statement}")
|
||||
if full_statement not in self.transformers_imports:
|
||||
source_code = get_module_source_from_name(full_statement)
|
||||
print(f"Source code found.")
|
||||
tree = cst.parse_module(source_code)
|
||||
self.transformers_imports[full_statement] = tree
|
||||
self.transformers_mapping[self.python_module.code_for_node(imported_.name)] = full_statement
|
||||
|
@ -84,17 +88,10 @@ class ImportVisitor(cst.CSTVisitor):
|
|||
if m.matches(node.value, m.Name()):
|
||||
parent_package = self.transformers_mapping.get(node.value.value, None)
|
||||
if parent_package:
|
||||
print(f" Model assignment. Finding Source code of {node.value.value} to replace the assignment")
|
||||
classes, renamed_module = find_classes_in_file(self.transformers_imports[parent_package])
|
||||
self.class_mapping[self.python_module.code_for_node(node)] = classes[node.targets[0].target.value]
|
||||
|
||||
class DiffConverterTransformer(CSTTransformer):
|
||||
def __init__(self, mapping, python_module):
|
||||
super().__init__()
|
||||
self.mapping = mapping
|
||||
self.python_module = python_module
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node:cst.Assign, updated_node:cst.CSTNode):
|
||||
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
|
||||
match updated_node:
|
||||
# note: this is just a plain copy & paste of the pattern as seen in the CST
|
||||
case cst.SimpleStatementLine(
|
||||
|
@ -107,14 +104,14 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
):
|
||||
assign = self.python_module.code_for_node(original_node.body[0])
|
||||
node = original_node.body[0]
|
||||
if m.matches(node.value, m.Name()) and assign in self.mapping:
|
||||
return self.mapping[assign]
|
||||
if m.matches(node.value, m.Name()) and assign in self.class_mapping:
|
||||
return self.class_mapping[assign]
|
||||
return updated_node
|
||||
|
||||
def leave_ClassDef(self, original_node:cst.Assign, node):
|
||||
print(node.name)
|
||||
def leave_ClassDef(self, original_node: cst.Assign, node):
|
||||
return node
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse the Python file
|
||||
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
|
||||
|
@ -122,11 +119,11 @@ if __name__ == "__main__":
|
|||
module = cst.parse_module(code)
|
||||
# find_modeling_imports(code)
|
||||
# Use the visitor to find imports
|
||||
visitor = ImportVisitor(module)
|
||||
module.visit(visitor)
|
||||
# visitor = ImportVisitor(module)
|
||||
# module.visit(visitor)
|
||||
|
||||
transformers = DiffConverterTransformer(visitor.class_mapping, module)
|
||||
new_mod = module.visit(transformers)
|
||||
transformers = DiffConverterTransformer(module)
|
||||
new_mod = module.visit(transformers)
|
||||
with open("result.py", "w") as f:
|
||||
f.write(new_mod.code)
|
||||
exit(0)
|
||||
|
|
Loading…
Reference in New Issue