This commit is contained in:
Arthur Zucker 2024-05-16 18:13:46 +02:00
parent df9e78377b
commit c804b4bc6d
4 changed files with 32 additions and 34 deletions

View File

@ -505,7 +505,7 @@ class GemmaSdpaAttention(GemmaAttention):
return attn_output, None, past_key_value
COHERE_ATTENTION_CLASSES = {
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
"flash_attention_2": GemmaFlashAttention2,
"sdpa": GemmaSdpaAttention,

View File

@ -505,7 +505,7 @@ class GemmaSdpaAttention(GemmaAttention):
return attn_output, None, past_key_value
COHERE_ATTENTION_CLASSES = {
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
"flash_attention_2": GemmaFlashAttention2,
"sdpa": GemmaSdpaAttention,

View File

@ -15,7 +15,6 @@ import argparse
import gc
import json
import os
import shutil
import warnings
import torch
@ -304,7 +303,9 @@ def write_model(
gc.collect()
print("Loading the checkpoint in a Llama model.")
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) # Avoid saving this as part of the config. del model.config._name_or_path model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path)
model = LlamaForCausalLM.from_pretrained(
tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
) # Avoid saving this as part of the config. del model.config._name_or_path model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path)
class Llama3Converter(TikTokenConverter):

View File

@ -1,21 +1,26 @@
import libcst as cst
from libcst import matchers as m
import importlib
import re
import libcst as cst
from libcst import matchers as m
# Should we use the scope to figure out if classes are imported and inherited from
# then go from here, instead of visiting the classes?
def get_module_source_from_name(module_name: str) -> str:
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
return f"Module {module_name} not found"
with open(spec.origin, 'r') as file:
with open(spec.origin, "r") as file:
source_code = file.read()
return source_code
from libcst import parse_module, CSTVisitor, CSTTransformer, ClassDef, Name
from libcst import ClassDef, CSTTransformer, CSTVisitor, Name
class ClassFinder(CSTVisitor):
@ -27,7 +32,6 @@ class ClassFinder(CSTVisitor):
self.classes[node.name.value] = node
class ReplaceNameTransformer(CSTTransformer):
def __init__(self, old_name, new_name):
self.new_name = new_name
@ -43,15 +47,15 @@ class ReplaceNameTransformer(CSTTransformer):
return self.new_name.title()
else:
return self.new_name.lower()
return self.regex.sub(replace, text)
def leave_Name(self, original_node: Name, updated_node: Name) -> Name:
# Replace 'Llama' with 'Cohere' in names
updated_value = self.preserve_case_replace(updated_node.value)
return updated_node.with_changes(value=updated_value)
def find_classes_in_file(module, old_id = "llama", new_id="gemma"):
def find_classes_in_file(module, old_id="llama", new_id="gemma"):
transformer = ReplaceNameTransformer(old_id, new_id)
new_module = module.visit(transformer)
@ -59,9 +63,11 @@ def find_classes_in_file(module, old_id = "llama", new_id="gemma"):
new_module.visit(class_finder)
return class_finder.classes, new_module
# Define a visitor to traverse the tree and find import statements
class ImportVisitor(cst.CSTVisitor):
class DiffConverterTransformer(CSTTransformer):
def __init__(self, python_module):
super().__init__()
self.python_module = python_module
self.transformers_imports = {}
self.transformers_mapping = {}
self.class_mapping = {}
@ -72,10 +78,8 @@ class ImportVisitor(cst.CSTVisitor):
full_statement = self.python_module.code_for_node(node.module)
for imported_ in node.names:
if "modeling_" in full_statement:
print(f"resolving import from {full_statement}")
if full_statement not in self.transformers_imports:
source_code = get_module_source_from_name(full_statement)
print(f"Source code found.")
tree = cst.parse_module(source_code)
self.transformers_imports[full_statement] = tree
self.transformers_mapping[self.python_module.code_for_node(imported_.name)] = full_statement
@ -84,17 +88,10 @@ class ImportVisitor(cst.CSTVisitor):
if m.matches(node.value, m.Name()):
parent_package = self.transformers_mapping.get(node.value.value, None)
if parent_package:
print(f" Model assignment. Finding Source code of {node.value.value} to replace the assignment")
classes, renamed_module = find_classes_in_file(self.transformers_imports[parent_package])
self.class_mapping[self.python_module.code_for_node(node)] = classes[node.targets[0].target.value]
class DiffConverterTransformer(CSTTransformer):
def __init__(self, mapping, python_module):
super().__init__()
self.mapping = mapping
self.python_module = python_module
def leave_SimpleStatementLine(self, original_node:cst.Assign, updated_node:cst.CSTNode):
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
match updated_node:
# note: this is just a plain copy & paste of the pattern as seen in the CST
case cst.SimpleStatementLine(
@ -107,14 +104,14 @@ class DiffConverterTransformer(CSTTransformer):
):
assign = self.python_module.code_for_node(original_node.body[0])
node = original_node.body[0]
if m.matches(node.value, m.Name()) and assign in self.mapping:
return self.mapping[assign]
if m.matches(node.value, m.Name()) and assign in self.class_mapping:
return self.class_mapping[assign]
return updated_node
def leave_ClassDef(self, original_node:cst.Assign, node):
print(node.name)
def leave_ClassDef(self, original_node: cst.Assign, node):
return node
if __name__ == "__main__":
# Parse the Python file
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
@ -122,10 +119,10 @@ if __name__ == "__main__":
module = cst.parse_module(code)
# find_modeling_imports(code)
# Use the visitor to find imports
visitor = ImportVisitor(module)
module.visit(visitor)
# visitor = ImportVisitor(module)
# module.visit(visitor)
transformers = DiffConverterTransformer(visitor.class_mapping, module)
transformers = DiffConverterTransformer(module)
new_mod = module.visit(transformers)
with open("result.py", "w") as f:
f.write(new_mod.code)