nit
This commit is contained in:
parent
f8587d7a2f
commit
07a90cc324
|
@ -128,7 +128,9 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
parent_package = self.transformers_mapping.get(node.value.value, None)
|
parent_package = self.transformers_mapping.get(node.value.value, None)
|
||||||
if parent_package:
|
if parent_package:
|
||||||
if parent_package not in self.visited_module:
|
if parent_package not in self.visited_module:
|
||||||
class_finder = find_classes_in_file(self.transformers_imports[parent_package])
|
old_name = re.findall(r'[A-Z][a-z0-9]*', node.value.value)[0].lower()
|
||||||
|
new_name = re.findall(r'[A-Z][a-z0-9]*', node.targets[0].target.value)[0].lower()
|
||||||
|
class_finder = find_classes_in_file(self.transformers_imports[parent_package], old_name, new_name)
|
||||||
self.visited_module[parent_package] = class_finder
|
self.visited_module[parent_package] = class_finder
|
||||||
self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[
|
self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[
|
||||||
parent_package
|
parent_package
|
||||||
|
@ -273,14 +275,31 @@ class SuperTransformer(cst.CSTTransformer):
|
||||||
from check_copies import run_ruff
|
from check_copies import run_ruff
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def convert_file(diff_file):
|
||||||
# Parse the Python file
|
# Parse the Python file
|
||||||
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
|
with open(diff_file, "r") as file:
|
||||||
code = file.read()
|
code = file.read()
|
||||||
module = cst.parse_module(code)
|
module = cst.parse_module(code)
|
||||||
transformers = DiffConverterTransformer(module)
|
transformers = DiffConverterTransformer(module)
|
||||||
new_mod = module.visit(transformers)
|
new_mod = module.visit(transformers)
|
||||||
ruffed_code = run_ruff(new_mod.code)
|
ruffed_code = run_ruff(new_mod.code)
|
||||||
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/modeling_gemma.py", "w") as f:
|
with open(diff_file.replace("diff_","modeling_"), "w") as f:
|
||||||
f.write(ruffed_code)
|
f.write(ruffed_code)
|
||||||
exit(0)
|
|
||||||
|
import glob
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--files_to_parse",
|
||||||
|
default="all",
|
||||||
|
help="A list of `diff_xxxx` files that should be converted to single model file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.files_to_parse == "all":
|
||||||
|
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
|
||||||
|
for file_name in args.files_to_parse:
|
||||||
|
print(f"Converting {file_name} to a single model single file format")
|
||||||
|
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
|
||||||
|
converter = convert_file(file_name)
|
||||||
|
|
Loading…
Reference in New Issue