fixes
This commit is contained in:
parent
b888fcdd1d
commit
80363e3fb7
|
@ -15,7 +15,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
from ...utils import (
|
||||
is_flash_attn_2_available,
|
||||
)
|
||||
|
@ -32,6 +31,7 @@ if is_flash_attn_2_available():
|
|||
|
||||
class GemmaConfig(PreTrainedConfig):
|
||||
model_type = "gemma"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
|
|
|
@ -44,6 +44,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
class GemmaConfig(PreTrainedConfig):
|
||||
model_type = "gemma"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
|
|
|
@ -810,7 +810,6 @@ GEMMA_INPUTS_DOCSTRING = r"""
|
|||
GEMMA_START_DOCSTRING,
|
||||
)
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
@ -1028,11 +1027,11 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
|
||||
return causal_mask
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "GemmaConfig"
|
||||
|
||||
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
|
|
|
@ -565,12 +565,14 @@ def run_ruff(code):
|
|||
stdout, _ = process.communicate(input=code.encode())
|
||||
return stdout.decode()
|
||||
|
||||
|
||||
def fix_ruff(code):
|
||||
command = ["ruff", "check", "-", "--fix", "--exit-zero"]
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
|
||||
stdout, _ = process.communicate(input=code.encode())
|
||||
return stdout.decode()
|
||||
|
||||
|
||||
def stylify(code: str) -> str:
|
||||
"""
|
||||
Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`.
|
||||
|
@ -758,9 +760,9 @@ def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = No
|
|||
else:
|
||||
# not in the target --> add it
|
||||
theoretical_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
|
||||
name_mappings_1[
|
||||
name_mappings_1[f"_ignored_new_block_{ignored_new_block_index}"] = (
|
||||
f"_ignored_new_block_{ignored_new_block_index}"
|
||||
] = f"_ignored_new_block_{ignored_new_block_index}"
|
||||
)
|
||||
|
||||
del observed_code_blocks[name]
|
||||
observed_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
|
||||
|
|
|
@ -7,8 +7,7 @@ import libcst as cst
|
|||
from check_copies import fix_ruff
|
||||
from libcst import ClassDef, CSTTransformer, CSTVisitor
|
||||
from libcst import matchers as m
|
||||
from libcst.metadata import MetadataWrapper, ParentNodeProvider, ScopeProvider, PositionProvider
|
||||
|
||||
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
|
||||
|
||||
|
||||
def get_module_source_from_name(module_name: str) -> str:
|
||||
|
@ -22,10 +21,9 @@ def get_module_source_from_name(module_name: str) -> str:
|
|||
return source_code
|
||||
|
||||
|
||||
|
||||
class ClassFinder(CSTVisitor):
|
||||
"""A visitor class which analyses a module, creating a mapping of dependencies between classes and functions.
|
||||
For example if the visited code has
|
||||
For example if the visited code has
|
||||
```python3
|
||||
def init_value(): return 1
|
||||
|
||||
|
@ -36,7 +34,8 @@ class ClassFinder(CSTVisitor):
|
|||
```
|
||||
then the `class_dependency_mapping` should be: `{"LlamaModel":{"PreTrainedModel":{},"init_value":{}}, "init_value":{}}
|
||||
"""
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,ScopeProvider,PositionProvider)
|
||||
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
|
||||
|
||||
def __init__(self, python_module):
|
||||
self.python_module = python_module
|
||||
|
@ -45,19 +44,19 @@ class ClassFinder(CSTVisitor):
|
|||
self.function_def = {} # def repeat_kv
|
||||
self.assignments = {} # LLAMA_DOCSTRING
|
||||
self.protected_imports = {} # if is_xxx_available()
|
||||
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
|
||||
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
|
||||
|
||||
def visit_ClassDef(self, node: ClassDef) -> None:
|
||||
self.classes[node.name.value] = node
|
||||
for k in node.bases: # deal with inheritance
|
||||
for k in node.bases: # deal with inheritance
|
||||
name = self.python_module.code_for_node(k)
|
||||
self.class_dependency_mapping.update(
|
||||
{
|
||||
node.name.value: set(self.class_dependency_mapping.get(name, {name})) | self.class_dependency_mapping.get(node.name.value, set())
|
||||
node.name.value: set(self.class_dependency_mapping.get(name, {name}))
|
||||
| self.class_dependency_mapping.get(node.name.value, set())
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def visit_SimpleStatementLine(self, node):
|
||||
match node:
|
||||
case cst.SimpleStatementLine(body=[cst.Assign(targets=[_], value=_)]):
|
||||
|
@ -78,37 +77,37 @@ class ClassFinder(CSTVisitor):
|
|||
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
|
||||
self.imports[stmt.body[0].names] = node # match the visit simple statement line to overwrite it
|
||||
|
||||
def leave_Name(self, node):
|
||||
def leave_Name(self, node):
|
||||
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
|
||||
dad = self.get_metadata(cst.metadata.ScopeProvider,node)
|
||||
dad = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||
if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
|
||||
print(f"Name:\t\t{node.value:<45} called in {dad._name_prefix}")
|
||||
print(f"Name:\t\t{node.value:<45} called in {dad._name_prefix}")
|
||||
name = dad._name_prefix.split(".")[0]
|
||||
dep = set(self.class_dependency_mapping.get(node.value,set()))
|
||||
dep |= set(self.class_dependency_mapping.get(name,{})) | set({node.value})
|
||||
dep = set(self.class_dependency_mapping.get(node.value, set()))
|
||||
dep |= set(self.class_dependency_mapping.get(name, {})) | set({node.value})
|
||||
self.class_dependency_mapping[name] = dep
|
||||
|
||||
def leave_Arg(self, node):
|
||||
def leave_Arg(self, node):
|
||||
if m.matches(node.value, m.Name()):
|
||||
dad = self.get_metadata(ParentNodeProvider,node)
|
||||
dad = self.get_metadata(ParentNodeProvider, node)
|
||||
if m.matches(dad, m.ClassDef()) and dad.bases:
|
||||
print(f"Arg:\t\t{node.value.value:<45} called in {dad.name.value}")
|
||||
print(f"Arg:\t\t{node.value.value:<45} called in {dad.name.value}")
|
||||
name = dad.name.value
|
||||
dep = set(self.class_dependency_mapping.get(node.value.value,set()))
|
||||
dep |= set(self.class_dependency_mapping.get(name,{})) | set({node.value.value})
|
||||
self.class_dependency_mapping[name] = dep
|
||||
dep = set(self.class_dependency_mapping.get(node.value.value, set()))
|
||||
dep |= set(self.class_dependency_mapping.get(name, {})) | set({node.value.value})
|
||||
self.class_dependency_mapping[name] = dep
|
||||
|
||||
def leave_Dict(self, node):
|
||||
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
if m.matches(dad,m.Assign(targets=[m.AssignTarget()])):
|
||||
if m.matches(dad, m.Assign(targets=[m.AssignTarget()])):
|
||||
name = dad.targets[0].target.value
|
||||
if name in self.assignments:
|
||||
for k in node.elements:
|
||||
if k.value.value in self.classes:
|
||||
dep = set(self.class_dependency_mapping.get(k.value.value,set()))
|
||||
dep |= self.class_dependency_mapping.get(name,set()) | set({k.value.value})
|
||||
dep = set(self.class_dependency_mapping.get(k.value.value, set()))
|
||||
dep |= self.class_dependency_mapping.get(name, set()) | set({k.value.value})
|
||||
self.class_dependency_mapping[name] = dep
|
||||
print(f"Dict:\t\t{k.value.value:<45} called in {name}")
|
||||
print(f"Dict:\t\t{k.value.value:<45} called in {name}")
|
||||
|
||||
# Decorator: handle in leave_FunctionDef and leave_ClassDef instead
|
||||
def leave_Decorator(self, node):
|
||||
|
@ -122,8 +121,8 @@ class ClassFinder(CSTVisitor):
|
|||
else:
|
||||
name = dad.name.value
|
||||
print(f"Decorator:\t{k.value.value:<45} called in {name}")
|
||||
dep = set(self.class_dependency_mapping.get(k.value.value,set()))
|
||||
dep |= self.class_dependency_mapping.get(name,set()) | set({k.value.value})
|
||||
dep = set(self.class_dependency_mapping.get(k.value.value, set()))
|
||||
dep |= self.class_dependency_mapping.get(name, set()) | set({k.value.value})
|
||||
self.class_dependency_mapping[name] = dep
|
||||
|
||||
def leave_Module(self, node):
|
||||
|
@ -131,11 +130,12 @@ class ClassFinder(CSTVisitor):
|
|||
# now sort the class dependency_mapping based on the position of the nodes
|
||||
self.class_start_line = {}
|
||||
for id, node in self.global_nodes.items():
|
||||
self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider,node).start.line
|
||||
self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
|
||||
|
||||
|
||||
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references"""
|
||||
|
||||
def __init__(self, old_name, new_name):
|
||||
super().__init__()
|
||||
self.new_name = new_name
|
||||
|
@ -173,6 +173,7 @@ def find_classes_in_file(module, old_id="llama", new_id="gemma"):
|
|||
wrapper.visit(class_finder)
|
||||
return class_finder
|
||||
|
||||
|
||||
class SuperTransformer(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||
|
||||
|
@ -203,7 +204,11 @@ class SuperTransformer(cst.CSTTransformer):
|
|||
if m.matches(
|
||||
expr,
|
||||
m.SimpleStatementLine(
|
||||
body=[m.Expr(value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))))]
|
||||
body=[
|
||||
m.Expr(
|
||||
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
|
||||
)
|
||||
]
|
||||
),
|
||||
):
|
||||
# Replace the SimpleStatementLine containing super().__init__() with the new body from func_to_body_mapping
|
||||
|
@ -223,14 +228,13 @@ class SuperTransformer(cst.CSTTransformer):
|
|||
new_body.append(expr)
|
||||
return node.with_changes(body=new_body)
|
||||
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
|
||||
if updated_node.name.value in self.updated_methods:
|
||||
name = updated_node.name.value
|
||||
name = updated_node.name.value
|
||||
new_body = self.replace_super_calls(updated_node.body, name)
|
||||
return updated_node.with_changes(body=new_body)
|
||||
return updated_node
|
||||
|
||||
|
||||
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
|
||||
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
|
||||
func_def = self.get_metadata(ParentNodeProvider, original_node)
|
||||
|
@ -246,11 +250,9 @@ class SuperTransformer(cst.CSTTransformer):
|
|||
return updated_node
|
||||
|
||||
|
||||
def replace_call_to_super(class_finder:ClassFinder, updated_node:cst.ClassDef, class_name:str):
|
||||
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
|
||||
original_node = class_finder.classes[class_name]
|
||||
original_methods = {
|
||||
f.name.value: f for f in original_node.body.body if m.matches(f, m.FunctionDef())
|
||||
}
|
||||
original_methods = {f.name.value: f for f in original_node.body.body if m.matches(f, m.FunctionDef())}
|
||||
|
||||
# Copy methods from original node to replacement node, preserving decorators
|
||||
updated_methods = {f.name.value: f for f in updated_node.body.body if m.matches(f, m.FunctionDef())}
|
||||
|
@ -264,12 +266,15 @@ def replace_call_to_super(class_finder:ClassFinder, updated_node:cst.ClassDef, c
|
|||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||
temp_module = cst.Module(body=[result_node])
|
||||
new_module = MetadataWrapper(temp_module)
|
||||
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[0] # get the indented block
|
||||
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[
|
||||
0
|
||||
] # get the indented block
|
||||
|
||||
return original_node.with_changes(body=new_replacement_class.body)
|
||||
|
||||
|
||||
class DiffConverterTransformer(CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,ScopeProvider,PositionProvider)
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
|
||||
|
||||
def __init__(self, python_module):
|
||||
super().__init__()
|
||||
|
@ -286,7 +291,7 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
def leave_FunctionDef(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.new_body.append(node)
|
||||
self.new_body.append(node)
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
|
@ -298,12 +303,11 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
source_code = get_module_source_from_name(import_statement)
|
||||
tree = cst.parse_module(source_code)
|
||||
self.transformers_imports[import_statement] = tree
|
||||
imported_class = self.python_module.code_for_node(imported_.name)
|
||||
imported_class = self.python_module.code_for_node(imported_.name)
|
||||
self.imported_mapping[imported_class] = import_statement
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
|
||||
match updated_node:
|
||||
|
||||
case cst.SimpleStatementLine(body=[cst.Assign(targets=[_], value=_)]):
|
||||
assign = self.python_module.code_for_node(original_node.body[0])
|
||||
node = original_node.body[0]
|
||||
|
@ -319,7 +323,7 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.new_body.append(updated_node)
|
||||
self.new_body.append(updated_node)
|
||||
return updated_node
|
||||
|
||||
def leave_ClassDef(self, original_node, updated_node):
|
||||
|
@ -329,23 +333,26 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
for super_class in bases:
|
||||
old_name = re.findall(r"[A-Z][a-z0-9]*", super_class)[0].lower()
|
||||
if super_class not in self.imported_mapping:
|
||||
raise ImportError(f"{super_class} was not imported using `from transformers.models.{old_name}.modeling_{old_name} import {super_class}")
|
||||
|
||||
super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree
|
||||
new_name = re.findall(r"[A-Z][a-z0-9]*", class_name)[0].lower()
|
||||
if super_file_name not in self.visited_module: # only extract classes once
|
||||
class_finder = find_classes_in_file(
|
||||
self.transformers_imports[super_file_name], old_name, new_name
|
||||
raise ImportError(
|
||||
f"{super_class} was not imported using `from transformers.models.{old_name}.modeling_{old_name} import {super_class}"
|
||||
)
|
||||
|
||||
super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree
|
||||
new_name = re.findall(r"[A-Z][a-z0-9]*", class_name)[0].lower()
|
||||
if super_file_name not in self.visited_module: # only extract classes once
|
||||
class_finder = find_classes_in_file(self.transformers_imports[super_file_name], old_name, new_name)
|
||||
self.visited_module[super_file_name] = class_finder
|
||||
self.class_mapping[class_name] = class_finder.classes[class_name]
|
||||
else:
|
||||
else:
|
||||
class_finder = self.visited_module[super_file_name]
|
||||
|
||||
list_dependencies = {dep:class_finder.class_start_line.get(dep,1000)for dep in class_finder.class_dependency_mapping[class_name]}
|
||||
for dependency, _ in sorted(list_dependencies.items(), key=lambda x:x[1]):
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping[class_name]
|
||||
}
|
||||
for dependency, _ in sorted(list_dependencies.items(), key=lambda x: x[1]):
|
||||
node = class_finder.global_nodes.get(dependency, None)
|
||||
# make sure the class is not re-defined by the diff file
|
||||
# make sure the class is not re-defined by the diff file
|
||||
if node is not None and node not in self.new_body:
|
||||
if dependency not in self.class_mapping:
|
||||
self.new_body.append(node)
|
||||
|
@ -370,19 +377,15 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
if m.matches(parent_node, m.Module()):
|
||||
self.new_body.append(node)
|
||||
return node
|
||||
|
||||
|
||||
def leave_Module(self, original_node: cst.Assign, node):
|
||||
new_body = []
|
||||
for visiter in self.visited_module.values():
|
||||
new_body += list(visiter.imports.values())
|
||||
self.config_body = list(visiter.imports.values()) + self.config_body
|
||||
self.config_body = list(visiter.imports.values()) + self.config_body
|
||||
return node.with_changes(body=[*new_body, *self.new_body])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def convert_file(diff_file):
|
||||
# Parse the Python file
|
||||
with open(diff_file, "r") as file:
|
||||
|
@ -397,7 +400,7 @@ def convert_file(diff_file):
|
|||
f.write(ruffed_code)
|
||||
|
||||
if hasattr(transformers, "config_body"):
|
||||
config_module = cst.Module(body = [*transformers.config_body], header=new_mod.header)
|
||||
config_module = cst.Module(body=[*transformers.config_body], header=new_mod.header)
|
||||
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
|
||||
ruffed_code = fix_ruff(config_module.code)
|
||||
f.write(ruffed_code)
|
||||
|
|
Loading…
Reference in New Issue