This commit is contained in:
Arthur Zucker 2024-05-15 16:44:46 +02:00
parent d3ab98e5ae
commit eaaf34f303
1 changed files with 66 additions and 63 deletions

View File

@ -43,77 +43,80 @@ def replace_super_calls_in_method(method_body, parent_method_body, method_name):
return method_body
# 2. Write all the classes. Use the `CohereConverter` class for this.
def create_single_model_file(converter):
model_identifier = converter.diff_file.split("diff_")
# temporarily add the source to the path in order to load everything?
# 1. Import all modules from the registered classes
modules = set([ _class.__module__ for _class in converter.registered_classes.values()]) or set()
for module in modules | {re.sub(r'.*src/(.*)\.py', r'\1', converter.diff_file).replace('/', '.')}:
modeling_ = importlib.import_module(module)
globals().update({k: getattr(modeling_, k) for k in modeling_.__dict__.keys()})
if hasattr(converter, "diff_file"):
model_identifier = converter.diff_file.split("diff_")
# temporarily add the source to the path in order to load everything?
# 1. Import all modules from the registered classes
modules = set([ _class.__module__ for _class in converter.registered_classes.values()]) or set()
for module in modules | {re.sub(r'.*src/(.*)\.py', r'\1', converter.diff_file).replace('/', '.')}:
modeling_ = importlib.import_module(module)
globals().update({k: getattr(modeling_, k) for k in modeling_.__dict__.keys()})
with open(converter.diff_file, 'r') as file, open(f"{model_identifier[0]}modeling_{model_identifier[1]}", "w+") as modeling:
modeling.write(APACHE_LICENCE)
function_set = {}
for line in file:
if "Converter.register" in line: # TODO use map() to map lines to this
# write the code of the original model
class_to_use, old_class = re.search(r'Converter\.register\(\"(.*?)\", (.*?)\)', line).groups()
model_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', class_to_use)[0]
old_model_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', old_class)[0]
# import all necessary modules from the path:
with open(converter.diff_file, 'r') as file, open(f"{model_identifier[0]}modeling_{model_identifier[1]}", "w+") as modeling:
modeling.write(APACHE_LICENCE)
function_set = {}
for line in file:
if "Converter.register" in line: # TODO use map() to map lines to this
# write the code of the original model
class_to_use, old_class = re.search(r'Converter\.register\(\"(.*?)\", (.*?)\)', line).groups()
model_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', class_to_use)[0]
old_model_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', old_class)[0]
# import all necessary modules from the path:
source_code = inspect.getsource(converter.registered_classes[class_to_use]).replace(old_class, class_to_use)
source_code = source_code.replace(old_model_identifier_camel, model_identifier_camel)
modeling.write(source_code)
modeling.write("\n")
source_code = inspect.getsource(converter.registered_classes[class_to_use]).replace(old_class, class_to_use)
source_code = source_code.replace(old_model_identifier_camel, model_identifier_camel)
modeling.write(source_code)
modeling.write("\n")
elif match:=re.match(r"class (\w+)\((\w+)\):", line):
class_name, parent_class = match.groups()
pattern = re.compile( r"(\ {4}(?:[\S\s\ \n]*?)(?=\n\ ^[\) ]|\n\n (?:def|@)|\Z))", re.MULTILINE)
parent_class_def = inspect.getsource(eval(parent_class))
modeling.write(parent_class_def.split('\n')[0].replace(parent_class,class_name)+"\n")
elif match:=re.match(r"class (\w+)\((\w+)\):", line):
class_name, parent_class = match.groups()
pattern = re.compile( r"(\ {4}(?:[\S\s\ \n]*?)(?=\n\ ^[\) ]|\n\n (?:def|@)|\Z))", re.MULTILINE)
parent_class_def = inspect.getsource(eval(parent_class))
modeling.write(parent_class_def.split('\n')[0].replace(parent_class,class_name)+"\n")
function_name_pattern = r"(?= def ([\S]*)\()"
function_body_pattern = r"(\ {4}(?:[\S\s\ \n]*?)(?=\n\ ^[\) ]|\n\n (?:def|@)|\Z))"
function_name_pattern = r"(?= def ([\S]*)\()"
function_body_pattern = r"(\ {4}(?:[\S\s\ \n]*?)(?=\n\ ^[\) ]|\n\n (?:def|@)|\Z))"
pattern = re.compile(function_body_pattern)
matches = pattern.finditer(parent_class_def)
parent_function_set = {}
for match in matches:
full_function = match.group()
print(full_function.split("def"))
if "def" in full_function:
parent_function_set[full_function.split("def")[1].split("(")[0]] = full_function
else:
parent_function_set[full_function] = full_function
child_function_set = parent_function_set.copy()
class_def = inspect.getsource(eval(class_name))
matches = pattern.finditer(class_def)
for match in matches:
# TODO handle call to super!
full_function = match.group()
if "def" in full_function:
function_name = full_function.split("def")[1].split("(")[0]
if f"super()." in full_function or f"return super()." in full_function:
replaced_function = replace_super_calls_in_method(full_function,
parent_function_set.get(function_name,
""),
function_name)
child_function_set[function_name] = replaced_function
pattern = re.compile(function_body_pattern)
matches = pattern.finditer(parent_class_def)
parent_function_set = {}
for match in matches:
full_function = match.group()
if "def" in full_function:
parent_function_set[full_function.split("def")[1].split("(")[0]] = full_function
else:
child_function_set[function_name] = full_function
else:
child_function_set[full_function] = full_function
parent_function_set[full_function] = full_function
modeling.write("\n".join(child_function_set.values())) # TODO we wrote the code, next lines shall be ignored
modeling.write("\n")
parent_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', parent_class)[0]
child_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', class_name)[0]
elif "= ModelConverter(__file__)" in line:
pass # don't write the converter to the result file
elif line not in "".join(function_set.values()) or line=="\n":
modeling.write(line)
child_function_set = parent_function_set.copy()
class_def = inspect.getsource(eval(class_name))
matches = pattern.finditer(class_def)
for match in matches:
# TODO handle call to super!
full_function = match.group()
if "def" in full_function:
function_name = full_function.split("def")[1].split("(")[0]
if (f"super()." in full_function or f"return super()." in full_function) and parent_identifier_camel != child_identifier_camel:
print(f"`{parent_identifier_camel}` `{child_identifier_camel}`")
replaced_function = replace_super_calls_in_method(full_function,
parent_function_set.get(function_name,
""),
function_name)
child_function_set[function_name] = replaced_function
else:
child_function_set[function_name] = full_function
else:
child_function_set[full_function] = full_function
modeling.write("\n".join(child_function_set.values())) # TODO we wrote the code, next lines shall be ignored
modeling.write("\n")
elif "= ModelConverter(__file__)" in line:
pass # don't write the converter to the result file
elif line not in "".join(function_set.values()) or line=="\n":
modeling.write(line)
def dynamically_import_object(module_path, object_name):