This commit is contained in:
Arthur Zucker 2024-05-29 08:06:25 +02:00
parent 42f640fba8
commit ac0dc69bb2
2 changed files with 45 additions and 50 deletions

View File

@ -559,15 +559,11 @@ def get_indent(code: str) -> str:
return ""
def run_ruff(code):
command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
stdout, _ = process.communicate(input=code.encode())
return stdout.decode()
def fix_ruff(code):
command = ["ruff", "check", "-", "--fix", "--exit-zero"]
def run_ruff(code, check=False):
if check:
command =["ruff", "check", "-", "--fix", "--exit-zero"]
else:
command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
stdout, _ = process.communicate(input=code.encode())
return stdout.decode()

View File

@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 the HuggingFace Inc. team. All rights reserved.
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,9 +18,8 @@ import importlib
import re
import libcst as cst
from check_copies import fix_ruff
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from check_copies import run_ruff
from libcst import ClassDef, CSTTransformer, CSTVisitor, matchers
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
from transformers import logging
@ -92,21 +91,21 @@ class ClassFinder(CSTVisitor):
)
def visit_SimpleStatementLine(self, node):
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()):
if matchers.matches(node, matchers.SimpleStatementLine(body=[matchers.Assign()])) and matchers.matches(self.get_metadata(cst.metadata.ParentNodeProvider, node), matchers.Module()):
self.assignments[node.body[0].targets[0].target.value] = node
if m.matches(node, m.SimpleStatement(body=[m.Import()])):
if matchers.matches(node, matchers.SimpleStatement(body=[matchers.Import()])):
self.imports[node.body[0].names] = node
if m.matches(node, m.SimpleStatement(body=[m.ImportFrom()])):
if matchers.matches(node, matchers.SimpleStatement(body=[matchers.ImportFrom()])):
self.imports[node.body[0].names] = node
def visit_FunctionDef(self, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent_node, m.Module()):
if matchers.matches(parent_node, matchers.Module()):
self.function_def[node.name.value] = node
def leave_If(self, node):
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
if matchers.matches(stmt, matchers.SimpleStatementLine(body=[matchers.ImportFrom() | matchers.Import()])):
self.imports[stmt.body[0].names] = node # match the visit simple statement line to overwrite it
def leave_Name(self, node):
@ -120,9 +119,9 @@ class ClassFinder(CSTVisitor):
self.class_dependency_mapping[name] = dep
def leave_Arg(self, node):
if m.matches(node.value, m.Name()):
if matchers.matches(node.value, matchers.Name()):
dad = self.get_metadata(ParentNodeProvider, node)
if m.matches(dad, m.ClassDef()) and dad.bases:
if matchers.matches(dad, matchers.ClassDef()) and dad.bases:
logger.info(f"Arg:\t\t{node.value.value:<45} called in {dad.name.value}")
name = dad.name.value
dep = set(self.class_dependency_mapping.get(node.value.value, set()))
@ -131,7 +130,7 @@ class ClassFinder(CSTVisitor):
def leave_Dict(self, node):
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(dad, m.Assign(targets=[m.AssignTarget()])):
if matchers.matches(dad, matchers.Assign(targets=[matchers.AssignTarget()])):
name = dad.targets[0].target.value
if name in self.assignments:
for k in node.elements:
@ -165,7 +164,7 @@ class ClassFinder(CSTVisitor):
self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
class ReplaceNameTransformer(matchers.MatcherDecoratableTransformer):
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references"""
def __init__(self, old_name, new_name):
@ -188,7 +187,7 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
return self.regex.sub(replace, text)
@m.leave(m.Name() | m.SimpleString() | m.Comment())
@matchers.leave(matchers.Name() | matchers.SimpleString() | matchers.Comment())
def replace_name(self, original_node, updated_node):
update = self.preserve_case_replace(updated_node.value)
return updated_node.with_changes(value=update)
@ -206,11 +205,11 @@ def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"):
return class_finder
DOCSTRING_NODE = m.SimpleStatementLine(
DOCSTRING_NODE = matchers.SimpleStatementLine(
body=[
m.Expr(
value=m.SimpleString(
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
matchers.Expr(
value=matchers.SimpleString(
value=matchers.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
)
)
]
@ -235,7 +234,7 @@ class SuperTransformer(cst.CSTTransformer):
}
for stmt in existing_body:
if self.python_module.code_for_node(stmt).strip() not in existing_nodes:
if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
if matchers.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
print("Oh docstring")
continue
de_duplicated_new_body.append(stmt)
@ -248,27 +247,27 @@ class SuperTransformer(cst.CSTTransformer):
new_body = []
self.has_docstring = False
for expr in node.body:
if m.matches(node.body[0], DOCSTRING_NODE):
if matchers.matches(node.body[0], DOCSTRING_NODE):
self.has_docstring = True
if m.matches(
if matchers.matches(
expr,
m.SimpleStatementLine(
matchers.SimpleStatementLine(
body=[
m.Expr(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
matchers.Expr(
value=matchers.Call(func=matchers.Attribute(value=matchers.Call(func=matchers.Name("super")), attr=matchers.Name(func_name)))
)
]
),
):
# Replace the SimpleStatementLine containing super().__init__() with the new body from func_to_body_mapping
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
elif m.matches(
elif matchers.matches(
expr,
m.SimpleStatementLine(
matchers.SimpleStatementLine(
body=[
m.Return(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
matchers.Return(
value=matchers.Call(func=matchers.Attribute(value=matchers.Call(func=matchers.Name("super")), attr=matchers.Name(func_name)))
)
]
),
@ -287,9 +286,9 @@ class SuperTransformer(cst.CSTTransformer):
return updated_node
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
if matchers.matches(updated_node.value, matchers.Call(func=matchers.Attribute(attr=matchers.Name("super")))):
func_def = self.get_metadata(ParentNodeProvider, original_node)
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
if matchers.matched(func_def, matchers.FunctionDef()) and func_def.name.value in self.original_methods:
updated_return_value = updated_node.value.with_changes(
args=[
cst.Arg(
@ -365,7 +364,7 @@ class DiffConverterTransformer(CSTTransformer):
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
if matchers.matches(parent_node, matchers.Module()):
self.new_body.append(node)
return node
@ -376,7 +375,7 @@ class DiffConverterTransformer(CSTTransformer):
3. Add this import to `self.transformers_imports` as visited to not parse it twice
"""
import_statement = self.python_module.code_for_node(node.module)
if m.matches(node.module, m.Attribute()):
if matchers.matches(node.module, matchers.Attribute()):
for imported_ in node.names:
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement):
if import_statement not in self.transformers_imports:
@ -387,24 +386,24 @@ class DiffConverterTransformer(CSTTransformer):
self.imported_mapping[imported_class] = import_statement
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
if matchers.matches(original_node, matchers.SimpleStatementLine(body=[matchers.Assign()])):
assign = self.python_module.code_for_node(original_node.body[0])
node = original_node.body[0]
if m.matches(node.value, m.Name()) and assign in self.node_mapping:
if matchers.matches(node.value, matchers.Name()) and assign in self.node_mapping:
return self.node_mapping[assign]
# remove all relative imports made in the diff file
full_statement = self.python_module.code_for_node(updated_node.body[0])
if m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
if matchers.matches(updated_node, matchers.SimpleStatementLine(body=[matchers.ImportFrom()])):
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
return cst.RemoveFromParent()
self.all_imports.append(updated_node)
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
if matchers.matches(updated_node, matchers.SimpleStatementLine(body=[matchers.Import()])):
self.all_imports.append(updated_node)
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
if matchers.matches(parent_node, matchers.Module()):
self.new_body.append(updated_node)
return updated_node
@ -470,13 +469,13 @@ class DiffConverterTransformer(CSTTransformer):
def leave_If(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
if matchers.matches(parent_node, matchers.Module()):
self.new_body.append(node)
return node
def leave_Expr(self, original_node: cst.Expr, node: cst.Expr) -> cst.Expr:
parent_node = self.get_metadata(cst.metadata.ScopeProvider, original_node)
if m.matches(parent_node, m.Module()):
if matchers.matches(parent_node, matchers.Module()):
self.new_body.append(node)
return node
@ -499,7 +498,7 @@ def convert_file(diff_file, cst_transformers=None):
if cst_transformers is None:
cst_transformers = DiffConverterTransformer(module)
new_mod = wrapper.visit(cst_transformers)
ruffed_code = fix_ruff(new_mod.code)
ruffed_code = run_ruff(new_mod.code, True)
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
f.write(AUTO_GENERATED_MESSAGE + ruffed_code)
@ -507,7 +506,7 @@ def convert_file(diff_file, cst_transformers=None):
if hasattr(cst_transformers, "config_body"):
config_module = cst.Module(body=[*cst_transformers.config_body], header=new_mod.header)
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
ruffed_code = fix_ruff(config_module.code)
ruffed_code = run_ruff(config_module.code, True)
f.write(AUTO_GENERATED_MESSAGE + ruffed_code)
# TODO optimize by re-using the class_finder