nit
This commit is contained in:
parent
f8587d7a2f
commit
07a90cc324
|
@ -128,7 +128,9 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
parent_package = self.transformers_mapping.get(node.value.value, None)
|
||||
if parent_package:
|
||||
if parent_package not in self.visited_module:
|
||||
class_finder = find_classes_in_file(self.transformers_imports[parent_package])
|
||||
old_name = re.findall(r'[A-Z][a-z0-9]*', node.value.value)[0].lower()
|
||||
new_name = re.findall(r'[A-Z][a-z0-9]*', node.targets[0].target.value)[0].lower()
|
||||
class_finder = find_classes_in_file(self.transformers_imports[parent_package], old_name, new_name)
|
||||
self.visited_module[parent_package] = class_finder
|
||||
self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[
|
||||
parent_package
|
||||
|
@ -273,14 +275,31 @@ class SuperTransformer(cst.CSTTransformer):
|
|||
from check_copies import run_ruff
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def convert_file(diff_file):
|
||||
# Parse the Python file
|
||||
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
|
||||
with open(diff_file, "r") as file:
|
||||
code = file.read()
|
||||
module = cst.parse_module(code)
|
||||
transformers = DiffConverterTransformer(module)
|
||||
new_mod = module.visit(transformers)
|
||||
ruffed_code = run_ruff(new_mod.code)
|
||||
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/modeling_gemma.py", "w") as f:
|
||||
with open(diff_file.replace("diff_","modeling_"), "w") as f:
|
||||
f.write(ruffed_code)
|
||||
exit(0)
|
||||
|
||||
import glob
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default="all",
|
||||
help="A list of `diff_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.files_to_parse == "all":
|
||||
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
|
||||
for file_name in args.files_to_parse:
|
||||
print(f"Converting {file_name} to a single model single file format")
|
||||
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
|
||||
converter = convert_file(file_name)
|
||||
|
|
Loading…
Reference in New Issue