This commit is contained in:
Arthur Zucker 2024-05-27 17:02:39 +02:00
parent b888fcdd1d
commit 80363e3fb7
5 changed files with 69 additions and 64 deletions

View File

@ -15,7 +15,6 @@
# limitations under the License.
from ...utils import (
is_flash_attn_2_available,
)
@ -32,6 +31,7 @@ if is_flash_attn_2_available():
class GemmaConfig(PreTrainedConfig):
model_type = "gemma"
def __init__(
self,
vocab_size=256000,

View File

@ -44,6 +44,7 @@ logger = logging.get_logger(__name__)
class GemmaConfig(PreTrainedConfig):
model_type = "gemma"
def __init__(
self,
vocab_size=256000,

View File

@ -810,7 +810,6 @@ GEMMA_INPUTS_DOCSTRING = r"""
GEMMA_START_DOCSTRING,
)
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
@ -1028,11 +1027,11 @@ class GemmaModel(GemmaPreTrainedModel):
return causal_mask
_CONFIG_FOR_DOC = "GemmaConfig"
class GemmaForCausalLM(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)

View File

@ -565,12 +565,14 @@ def run_ruff(code):
stdout, _ = process.communicate(input=code.encode())
return stdout.decode()
def fix_ruff(code):
command = ["ruff", "check", "-", "--fix", "--exit-zero"]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
stdout, _ = process.communicate(input=code.encode())
return stdout.decode()
def stylify(code: str) -> str:
"""
Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`.
@ -758,9 +760,9 @@ def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = No
else:
# not in the target --> add it
theoretical_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
name_mappings_1[
name_mappings_1[f"_ignored_new_block_{ignored_new_block_index}"] = (
f"_ignored_new_block_{ignored_new_block_index}"
] = f"_ignored_new_block_{ignored_new_block_index}"
)
del observed_code_blocks[name]
observed_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code

View File

@ -7,8 +7,7 @@ import libcst as cst
from check_copies import fix_ruff
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, ScopeProvider, PositionProvider
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
def get_module_source_from_name(module_name: str) -> str:
@ -22,10 +21,9 @@ def get_module_source_from_name(module_name: str) -> str:
return source_code
class ClassFinder(CSTVisitor):
"""A visitor class which analyses a module, creating a mapping of dependencies between classes and functions.
For example if the visited code has
For example if the visited code has
```python3
def init_value(): return 1
@ -36,7 +34,8 @@ class ClassFinder(CSTVisitor):
```
then the `class_dependency_mapping` should be: `{"LlamaModel":{"PreTrainedModel":{},"init_value":{}}, "init_value":{}}
"""
METADATA_DEPENDENCIES = (ParentNodeProvider,ScopeProvider,PositionProvider)
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module):
self.python_module = python_module
@ -45,19 +44,19 @@ class ClassFinder(CSTVisitor):
self.function_def = {} # def repeat_kv
self.assignments = {} # LLAMA_DOCSTRING
self.protected_imports = {} # if is_xxx_available()
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
def visit_ClassDef(self, node: ClassDef) -> None:
self.classes[node.name.value] = node
for k in node.bases: # deal with inheritance
for k in node.bases: # deal with inheritance
name = self.python_module.code_for_node(k)
self.class_dependency_mapping.update(
{
node.name.value: set(self.class_dependency_mapping.get(name, {name})) | self.class_dependency_mapping.get(node.name.value, set())
node.name.value: set(self.class_dependency_mapping.get(name, {name}))
| self.class_dependency_mapping.get(node.name.value, set())
}
)
def visit_SimpleStatementLine(self, node):
match node:
case cst.SimpleStatementLine(body=[cst.Assign(targets=[_], value=_)]):
@ -78,37 +77,37 @@ class ClassFinder(CSTVisitor):
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports[stmt.body[0].names] = node # match the visit simple statement line to overwrite it
def leave_Name(self, node):
def leave_Name(self, node):
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
dad = self.get_metadata(cst.metadata.ScopeProvider,node)
dad = self.get_metadata(cst.metadata.ScopeProvider, node)
if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
print(f"Name:\t\t{node.value:<45} called in {dad._name_prefix}")
print(f"Name:\t\t{node.value:<45} called in {dad._name_prefix}")
name = dad._name_prefix.split(".")[0]
dep = set(self.class_dependency_mapping.get(node.value,set()))
dep |= set(self.class_dependency_mapping.get(name,{})) | set({node.value})
dep = set(self.class_dependency_mapping.get(node.value, set()))
dep |= set(self.class_dependency_mapping.get(name, {})) | set({node.value})
self.class_dependency_mapping[name] = dep
def leave_Arg(self, node):
def leave_Arg(self, node):
if m.matches(node.value, m.Name()):
dad = self.get_metadata(ParentNodeProvider,node)
dad = self.get_metadata(ParentNodeProvider, node)
if m.matches(dad, m.ClassDef()) and dad.bases:
print(f"Arg:\t\t{node.value.value:<45} called in {dad.name.value}")
print(f"Arg:\t\t{node.value.value:<45} called in {dad.name.value}")
name = dad.name.value
dep = set(self.class_dependency_mapping.get(node.value.value,set()))
dep |= set(self.class_dependency_mapping.get(name,{})) | set({node.value.value})
self.class_dependency_mapping[name] = dep
dep = set(self.class_dependency_mapping.get(node.value.value, set()))
dep |= set(self.class_dependency_mapping.get(name, {})) | set({node.value.value})
self.class_dependency_mapping[name] = dep
def leave_Dict(self, node):
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(dad,m.Assign(targets=[m.AssignTarget()])):
if m.matches(dad, m.Assign(targets=[m.AssignTarget()])):
name = dad.targets[0].target.value
if name in self.assignments:
for k in node.elements:
if k.value.value in self.classes:
dep = set(self.class_dependency_mapping.get(k.value.value,set()))
dep |= self.class_dependency_mapping.get(name,set()) | set({k.value.value})
dep = set(self.class_dependency_mapping.get(k.value.value, set()))
dep |= self.class_dependency_mapping.get(name, set()) | set({k.value.value})
self.class_dependency_mapping[name] = dep
print(f"Dict:\t\t{k.value.value:<45} called in {name}")
print(f"Dict:\t\t{k.value.value:<45} called in {name}")
# Decorator: handle in leave_FunctionDef and leave_ClassDef instead
def leave_Decorator(self, node):
@ -122,8 +121,8 @@ class ClassFinder(CSTVisitor):
else:
name = dad.name.value
print(f"Decorator:\t{k.value.value:<45} called in {name}")
dep = set(self.class_dependency_mapping.get(k.value.value,set()))
dep |= self.class_dependency_mapping.get(name,set()) | set({k.value.value})
dep = set(self.class_dependency_mapping.get(k.value.value, set()))
dep |= self.class_dependency_mapping.get(name, set()) | set({k.value.value})
self.class_dependency_mapping[name] = dep
def leave_Module(self, node):
@ -131,11 +130,12 @@ class ClassFinder(CSTVisitor):
# now sort the class dependency_mapping based on the position of the nodes
self.class_start_line = {}
for id, node in self.global_nodes.items():
self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider,node).start.line
self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references"""
def __init__(self, old_name, new_name):
super().__init__()
self.new_name = new_name
@ -173,6 +173,7 @@ def find_classes_in_file(module, old_id="llama", new_id="gemma"):
wrapper.visit(class_finder)
return class_finder
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
@ -203,7 +204,11 @@ class SuperTransformer(cst.CSTTransformer):
if m.matches(
expr,
m.SimpleStatementLine(
body=[m.Expr(value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))))]
body=[
m.Expr(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
)
]
),
):
# Replace the SimpleStatementLine containing super().__init__() with the new body from func_to_body_mapping
@ -223,14 +228,13 @@ class SuperTransformer(cst.CSTTransformer):
new_body.append(expr)
return node.with_changes(body=new_body)
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
if updated_node.name.value in self.updated_methods:
name = updated_node.name.value
name = updated_node.name.value
new_body = self.replace_super_calls(updated_node.body, name)
return updated_node.with_changes(body=new_body)
return updated_node
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
func_def = self.get_metadata(ParentNodeProvider, original_node)
@ -246,11 +250,9 @@ class SuperTransformer(cst.CSTTransformer):
return updated_node
def replace_call_to_super(class_finder:ClassFinder, updated_node:cst.ClassDef, class_name:str):
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
original_node = class_finder.classes[class_name]
original_methods = {
f.name.value: f for f in original_node.body.body if m.matches(f, m.FunctionDef())
}
original_methods = {f.name.value: f for f in original_node.body.body if m.matches(f, m.FunctionDef())}
# Copy methods from original node to replacement node, preserving decorators
updated_methods = {f.name.value: f for f in updated_node.body.body if m.matches(f, m.FunctionDef())}
@ -264,12 +266,15 @@ def replace_call_to_super(class_finder:ClassFinder, updated_node:cst.ClassDef, c
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[0] # get the indented block
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[
0
] # get the indented block
return original_node.with_changes(body=new_replacement_class.body)
class DiffConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,ScopeProvider,PositionProvider)
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module):
super().__init__()
@ -286,7 +291,7 @@ class DiffConverterTransformer(CSTTransformer):
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.new_body.append(node)
self.new_body.append(node)
return node
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
@ -298,12 +303,11 @@ class DiffConverterTransformer(CSTTransformer):
source_code = get_module_source_from_name(import_statement)
tree = cst.parse_module(source_code)
self.transformers_imports[import_statement] = tree
imported_class = self.python_module.code_for_node(imported_.name)
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
match updated_node:
case cst.SimpleStatementLine(body=[cst.Assign(targets=[_], value=_)]):
assign = self.python_module.code_for_node(original_node.body[0])
node = original_node.body[0]
@ -319,7 +323,7 @@ class DiffConverterTransformer(CSTTransformer):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.new_body.append(updated_node)
self.new_body.append(updated_node)
return updated_node
def leave_ClassDef(self, original_node, updated_node):
@ -329,23 +333,26 @@ class DiffConverterTransformer(CSTTransformer):
for super_class in bases:
old_name = re.findall(r"[A-Z][a-z0-9]*", super_class)[0].lower()
if super_class not in self.imported_mapping:
raise ImportError(f"{super_class} was not imported using `from transformers.models.{old_name}.modeling_{old_name} import {super_class}")
super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree
new_name = re.findall(r"[A-Z][a-z0-9]*", class_name)[0].lower()
if super_file_name not in self.visited_module: # only extract classes once
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name], old_name, new_name
raise ImportError(
f"{super_class} was not imported using `from transformers.models.{old_name}.modeling_{old_name} import {super_class}"
)
super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree
new_name = re.findall(r"[A-Z][a-z0-9]*", class_name)[0].lower()
if super_file_name not in self.visited_module: # only extract classes once
class_finder = find_classes_in_file(self.transformers_imports[super_file_name], old_name, new_name)
self.visited_module[super_file_name] = class_finder
self.class_mapping[class_name] = class_finder.classes[class_name]
else:
else:
class_finder = self.visited_module[super_file_name]
list_dependencies = {dep:class_finder.class_start_line.get(dep,1000)for dep in class_finder.class_dependency_mapping[class_name]}
for dependency, _ in sorted(list_dependencies.items(), key=lambda x:x[1]):
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping[class_name]
}
for dependency, _ in sorted(list_dependencies.items(), key=lambda x: x[1]):
node = class_finder.global_nodes.get(dependency, None)
# make sure the class is not re-defined by the diff file
# make sure the class is not re-defined by the diff file
if node is not None and node not in self.new_body:
if dependency not in self.class_mapping:
self.new_body.append(node)
@ -370,19 +377,15 @@ class DiffConverterTransformer(CSTTransformer):
if m.matches(parent_node, m.Module()):
self.new_body.append(node)
return node
def leave_Module(self, original_node: cst.Assign, node):
new_body = []
for visiter in self.visited_module.values():
new_body += list(visiter.imports.values())
self.config_body = list(visiter.imports.values()) + self.config_body
self.config_body = list(visiter.imports.values()) + self.config_body
return node.with_changes(body=[*new_body, *self.new_body])
def convert_file(diff_file):
# Parse the Python file
with open(diff_file, "r") as file:
@ -397,7 +400,7 @@ def convert_file(diff_file):
f.write(ruffed_code)
if hasattr(transformers, "config_body"):
config_module = cst.Module(body = [*transformers.config_body], header=new_mod.header)
config_module = cst.Module(body=[*transformers.config_body], header=new_mod.header)
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
ruffed_code = fix_ruff(config_module.code)
f.write(ruffed_code)