fixup
This commit is contained in:
parent
df9e78377b
commit
c804b4bc6d
|
@ -505,7 +505,7 @@ class GemmaSdpaAttention(GemmaAttention):
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
COHERE_ATTENTION_CLASSES = {
|
GEMMA_ATTENTION_CLASSES = {
|
||||||
"eager": GemmaAttention,
|
"eager": GemmaAttention,
|
||||||
"flash_attention_2": GemmaFlashAttention2,
|
"flash_attention_2": GemmaFlashAttention2,
|
||||||
"sdpa": GemmaSdpaAttention,
|
"sdpa": GemmaSdpaAttention,
|
||||||
|
|
|
@ -505,7 +505,7 @@ class GemmaSdpaAttention(GemmaAttention):
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
COHERE_ATTENTION_CLASSES = {
|
GEMMA_ATTENTION_CLASSES = {
|
||||||
"eager": GemmaAttention,
|
"eager": GemmaAttention,
|
||||||
"flash_attention_2": GemmaFlashAttention2,
|
"flash_attention_2": GemmaFlashAttention2,
|
||||||
"sdpa": GemmaSdpaAttention,
|
"sdpa": GemmaSdpaAttention,
|
||||||
|
|
|
@ -15,7 +15,6 @@ import argparse
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -304,7 +303,9 @@ def write_model(
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print("Loading the checkpoint in a Llama model.")
|
print("Loading the checkpoint in a Llama model.")
|
||||||
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) # Avoid saving this as part of the config. del model.config._name_or_path model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path)
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
|
||||||
|
) # Avoid saving this as part of the config. del model.config._name_or_path model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path)
|
||||||
|
|
||||||
|
|
||||||
class Llama3Converter(TikTokenConverter):
|
class Llama3Converter(TikTokenConverter):
|
||||||
|
|
|
@ -1,21 +1,26 @@
|
||||||
import libcst as cst
|
|
||||||
from libcst import matchers as m
|
|
||||||
import importlib
|
import importlib
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import libcst as cst
|
||||||
|
from libcst import matchers as m
|
||||||
|
|
||||||
|
|
||||||
# Should we use the scope to figure out if classes are imported and inherited from
|
# Should we use the scope to figure out if classes are imported and inherited from
|
||||||
# then go from here, instead of visiting the classes?
|
# then go from here, instead of visiting the classes?
|
||||||
|
|
||||||
|
|
||||||
def get_module_source_from_name(module_name: str) -> str:
|
def get_module_source_from_name(module_name: str) -> str:
|
||||||
spec = importlib.util.find_spec(module_name)
|
spec = importlib.util.find_spec(module_name)
|
||||||
if spec is None or spec.origin is None:
|
if spec is None or spec.origin is None:
|
||||||
return f"Module {module_name} not found"
|
return f"Module {module_name} not found"
|
||||||
|
|
||||||
with open(spec.origin, 'r') as file:
|
with open(spec.origin, "r") as file:
|
||||||
source_code = file.read()
|
source_code = file.read()
|
||||||
|
|
||||||
return source_code
|
return source_code
|
||||||
|
|
||||||
from libcst import parse_module, CSTVisitor, CSTTransformer, ClassDef, Name
|
|
||||||
|
from libcst import ClassDef, CSTTransformer, CSTVisitor, Name
|
||||||
|
|
||||||
|
|
||||||
class ClassFinder(CSTVisitor):
|
class ClassFinder(CSTVisitor):
|
||||||
|
@ -27,7 +32,6 @@ class ClassFinder(CSTVisitor):
|
||||||
self.classes[node.name.value] = node
|
self.classes[node.name.value] = node
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ReplaceNameTransformer(CSTTransformer):
|
class ReplaceNameTransformer(CSTTransformer):
|
||||||
def __init__(self, old_name, new_name):
|
def __init__(self, old_name, new_name):
|
||||||
self.new_name = new_name
|
self.new_name = new_name
|
||||||
|
@ -43,15 +47,15 @@ class ReplaceNameTransformer(CSTTransformer):
|
||||||
return self.new_name.title()
|
return self.new_name.title()
|
||||||
else:
|
else:
|
||||||
return self.new_name.lower()
|
return self.new_name.lower()
|
||||||
|
|
||||||
return self.regex.sub(replace, text)
|
return self.regex.sub(replace, text)
|
||||||
|
|
||||||
def leave_Name(self, original_node: Name, updated_node: Name) -> Name:
|
def leave_Name(self, original_node: Name, updated_node: Name) -> Name:
|
||||||
# Replace 'Llama' with 'Cohere' in names
|
|
||||||
updated_value = self.preserve_case_replace(updated_node.value)
|
updated_value = self.preserve_case_replace(updated_node.value)
|
||||||
return updated_node.with_changes(value=updated_value)
|
return updated_node.with_changes(value=updated_value)
|
||||||
|
|
||||||
|
|
||||||
def find_classes_in_file(module, old_id = "llama", new_id="gemma"):
|
def find_classes_in_file(module, old_id="llama", new_id="gemma"):
|
||||||
transformer = ReplaceNameTransformer(old_id, new_id)
|
transformer = ReplaceNameTransformer(old_id, new_id)
|
||||||
new_module = module.visit(transformer)
|
new_module = module.visit(transformer)
|
||||||
|
|
||||||
|
@ -59,9 +63,11 @@ def find_classes_in_file(module, old_id = "llama", new_id="gemma"):
|
||||||
new_module.visit(class_finder)
|
new_module.visit(class_finder)
|
||||||
return class_finder.classes, new_module
|
return class_finder.classes, new_module
|
||||||
|
|
||||||
# Define a visitor to traverse the tree and find import statements
|
|
||||||
class ImportVisitor(cst.CSTVisitor):
|
class DiffConverterTransformer(CSTTransformer):
|
||||||
def __init__(self, python_module):
|
def __init__(self, python_module):
|
||||||
|
super().__init__()
|
||||||
|
self.python_module = python_module
|
||||||
self.transformers_imports = {}
|
self.transformers_imports = {}
|
||||||
self.transformers_mapping = {}
|
self.transformers_mapping = {}
|
||||||
self.class_mapping = {}
|
self.class_mapping = {}
|
||||||
|
@ -72,10 +78,8 @@ class ImportVisitor(cst.CSTVisitor):
|
||||||
full_statement = self.python_module.code_for_node(node.module)
|
full_statement = self.python_module.code_for_node(node.module)
|
||||||
for imported_ in node.names:
|
for imported_ in node.names:
|
||||||
if "modeling_" in full_statement:
|
if "modeling_" in full_statement:
|
||||||
print(f"resolving import from {full_statement}")
|
|
||||||
if full_statement not in self.transformers_imports:
|
if full_statement not in self.transformers_imports:
|
||||||
source_code = get_module_source_from_name(full_statement)
|
source_code = get_module_source_from_name(full_statement)
|
||||||
print(f"Source code found.")
|
|
||||||
tree = cst.parse_module(source_code)
|
tree = cst.parse_module(source_code)
|
||||||
self.transformers_imports[full_statement] = tree
|
self.transformers_imports[full_statement] = tree
|
||||||
self.transformers_mapping[self.python_module.code_for_node(imported_.name)] = full_statement
|
self.transformers_mapping[self.python_module.code_for_node(imported_.name)] = full_statement
|
||||||
|
@ -84,17 +88,10 @@ class ImportVisitor(cst.CSTVisitor):
|
||||||
if m.matches(node.value, m.Name()):
|
if m.matches(node.value, m.Name()):
|
||||||
parent_package = self.transformers_mapping.get(node.value.value, None)
|
parent_package = self.transformers_mapping.get(node.value.value, None)
|
||||||
if parent_package:
|
if parent_package:
|
||||||
print(f" Model assignment. Finding Source code of {node.value.value} to replace the assignment")
|
|
||||||
classes, renamed_module = find_classes_in_file(self.transformers_imports[parent_package])
|
classes, renamed_module = find_classes_in_file(self.transformers_imports[parent_package])
|
||||||
self.class_mapping[self.python_module.code_for_node(node)] = classes[node.targets[0].target.value]
|
self.class_mapping[self.python_module.code_for_node(node)] = classes[node.targets[0].target.value]
|
||||||
|
|
||||||
class DiffConverterTransformer(CSTTransformer):
|
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
|
||||||
def __init__(self, mapping, python_module):
|
|
||||||
super().__init__()
|
|
||||||
self.mapping = mapping
|
|
||||||
self.python_module = python_module
|
|
||||||
|
|
||||||
def leave_SimpleStatementLine(self, original_node:cst.Assign, updated_node:cst.CSTNode):
|
|
||||||
match updated_node:
|
match updated_node:
|
||||||
# note: this is just a plain copy & paste of the pattern as seen in the CST
|
# note: this is just a plain copy & paste of the pattern as seen in the CST
|
||||||
case cst.SimpleStatementLine(
|
case cst.SimpleStatementLine(
|
||||||
|
@ -107,14 +104,14 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
):
|
):
|
||||||
assign = self.python_module.code_for_node(original_node.body[0])
|
assign = self.python_module.code_for_node(original_node.body[0])
|
||||||
node = original_node.body[0]
|
node = original_node.body[0]
|
||||||
if m.matches(node.value, m.Name()) and assign in self.mapping:
|
if m.matches(node.value, m.Name()) and assign in self.class_mapping:
|
||||||
return self.mapping[assign]
|
return self.class_mapping[assign]
|
||||||
return updated_node
|
return updated_node
|
||||||
|
|
||||||
def leave_ClassDef(self, original_node:cst.Assign, node):
|
def leave_ClassDef(self, original_node: cst.Assign, node):
|
||||||
print(node.name)
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Parse the Python file
|
# Parse the Python file
|
||||||
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
|
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
|
||||||
|
@ -122,10 +119,10 @@ if __name__ == "__main__":
|
||||||
module = cst.parse_module(code)
|
module = cst.parse_module(code)
|
||||||
# find_modeling_imports(code)
|
# find_modeling_imports(code)
|
||||||
# Use the visitor to find imports
|
# Use the visitor to find imports
|
||||||
visitor = ImportVisitor(module)
|
# visitor = ImportVisitor(module)
|
||||||
module.visit(visitor)
|
# module.visit(visitor)
|
||||||
|
|
||||||
transformers = DiffConverterTransformer(visitor.class_mapping, module)
|
transformers = DiffConverterTransformer(module)
|
||||||
new_mod = module.visit(transformers)
|
new_mod = module.visit(transformers)
|
||||||
with open("result.py", "w") as f:
|
with open("result.py", "w") as f:
|
||||||
f.write(new_mod.code)
|
f.write(new_mod.code)
|
||||||
|
|
Loading…
Reference in New Issue