This commit is contained in:
Arthur Zucker 2024-05-18 12:49:27 +02:00
parent f8587d7a2f
commit 07a90cc324
1 changed files with 24 additions and 5 deletions

View File

@ -128,7 +128,9 @@ class DiffConverterTransformer(CSTTransformer):
parent_package = self.transformers_mapping.get(node.value.value, None) parent_package = self.transformers_mapping.get(node.value.value, None)
if parent_package: if parent_package:
if parent_package not in self.visited_module: if parent_package not in self.visited_module:
class_finder = find_classes_in_file(self.transformers_imports[parent_package]) old_name = re.findall(r'[A-Z][a-z0-9]*', node.value.value)[0].lower()
new_name = re.findall(r'[A-Z][a-z0-9]*', node.targets[0].target.value)[0].lower()
class_finder = find_classes_in_file(self.transformers_imports[parent_package], old_name, new_name)
self.visited_module[parent_package] = class_finder self.visited_module[parent_package] = class_finder
self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[ self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[
parent_package parent_package
@ -273,14 +275,31 @@ class SuperTransformer(cst.CSTTransformer):
from check_copies import run_ruff from check_copies import run_ruff
if __name__ == "__main__": def convert_file(diff_file):
# Parse the Python file # Parse the Python file
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file: with open(diff_file, "r") as file:
code = file.read() code = file.read()
module = cst.parse_module(code) module = cst.parse_module(code)
transformers = DiffConverterTransformer(module) transformers = DiffConverterTransformer(module)
new_mod = module.visit(transformers) new_mod = module.visit(transformers)
ruffed_code = run_ruff(new_mod.code) ruffed_code = run_ruff(new_mod.code)
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/modeling_gemma.py", "w") as f: with open(diff_file.replace("diff_","modeling_"), "w") as f:
f.write(ruffed_code) f.write(ruffed_code)
exit(0)
import glob
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default="all",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
args = parser.parse_args()
if args.files_to_parse == "all":
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
for file_name in args.files_to_parse:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converter = convert_file(file_name)