No need for call

This commit is contained in:
Arthur Zucker 2024-05-25 10:05:28 +02:00
parent 585686ed08
commit 91f45f8ff2
1 changed files with 12 additions and 10 deletions

View File

@ -59,19 +59,19 @@ class ClassFinder(CSTVisitor):
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports[stmt.body[0].names] = node # match the visit simple statement line to overwrite it
def leave_Call(self, node):
if self.python_module.code_for_node(node.func) in self.function_def or node.func.value in self.classes:
dad = self.get_metadata(cst.metadata.ScopeProvider,node.func)
if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
if hasattr(dad, "_name_prefix"):
print(f"Call: {node.func.value:<15} called in {dad._name_prefix:>10}")
# def leave_Call(self, node):
# if self.python_module.code_for_node(node.func) in self.function_def or node.func.value in self.classes:
# dad = self.get_metadata(cst.metadata.ScopeProvider,node.func)
# if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
# if hasattr(dad, "_name_prefix"):
# print(f"Call:\t\t{node.func.value:<45} called in {dad._name_prefix}")
def leave_Name(self, node):
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
dad = self.get_metadata(cst.metadata.ScopeProvider,node)
if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
if hasattr(dad, "_name_prefix"):
print(f"Name: {node.value:<15} called in {dad._name_prefix:>10}, {dad.name}")
print(f"Name:\t\t{node.value:<45} called in {dad._name_prefix}")
def leave_Dict(self, node):
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
@ -80,14 +80,14 @@ class ClassFinder(CSTVisitor):
if name in self.assignments:
for k in node.elements:
if k.value.value in self.classes:
print(f"Dict: {k.value.value:<15} called in {name:>10}")
print(f"Dict:\t\t{k.value.value:<45} called in {name}")
# Decorator: in leave_FunctionDef and leave_ClassDef
# Decorator: handle in leave_FunctionDef and leave_ClassDef instead
def leave_Decorator(self, node):
if hasattr(node.decorator, "args"):
for k in node.decorator.args:
if k.value.value in self.assignments:
print(f"Decorator: {k.value.value} called in {self.get_metadata(cst.metadata.ParentNodeProvider, node).name.value}")
print(f"Decorator:\t{k.value.value:<45} called in {self.get_metadata(cst.metadata.ParentNodeProvider, node).name.value}")
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
def __init__(self, old_name, new_name):
@ -194,6 +194,7 @@ class DiffConverterTransformer(CSTTransformer):
if m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
if re.search(r"transformers\.models\..*\.[modeling|configuration]_.*", full_statement):
# TODO use remove from parent. `return RemoveFromParent()`
return updated_node.with_changes(body=[])
return updated_node
@ -347,6 +348,7 @@ def convert_file(diff_file):
transformers = DiffConverterTransformer(module)
new_mod = module.visit(transformers)
ruffed_code = fix_ruff(new_mod.code)
exit(0)
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
f.write(ruffed_code)