This commit is contained in:
Arthur Zucker 2024-05-28 09:38:22 +02:00
parent d6ef9e81e5
commit 1ce5c1b5c7
3 changed files with 17 additions and 15 deletions

View File

@ -15,7 +15,6 @@
# limitations under the License.
from ...utils import (
is_flash_attn_2_available,
)

View File

@ -812,7 +812,6 @@ _CONFIG_FOR_DOC = "GemmaConfig"
GEMMA_START_DOCSTRING,
)
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
@ -1032,7 +1031,6 @@ class GemmaModel(GemmaPreTrainedModel):
class GemmaForCausalLM(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)

View File

@ -11,8 +11,10 @@ from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvide
from transformers import logging
logger = logging.get_logger(__name__)
def get_module_source_from_name(module_name: str) -> str:
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
@ -282,12 +284,12 @@ class DiffConverterTransformer(CSTTransformer):
def __init__(self, python_module):
super().__init__()
self.python_module = python_module # we store the original module to use `code_for_node`
self.transformers_imports = {} # all the imports made from transformers.models.xxx
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.node_mapping = {} # stores the name of the nodes that were added to the `new_body`
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.new_body = [] # store the new body, all global scope nodes should be added here
self.inserted_deps = [] # nodes inserted via super dependency
self.transformers_imports = {} # all the imports made from transformers.models.xxx
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.node_mapping = {} # stores the name of the nodes that were added to the `new_body`
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.new_body = [] # store the new body, all global scope nodes should be added here
self.inserted_deps = [] # nodes inserted via super dependency
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
@ -296,8 +298,8 @@ class DiffConverterTransformer(CSTTransformer):
return node
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
""" When visiting imports from `transformers.models.xxx` we need to:
1. Get the original source code
"""When visiting imports from `transformers.models.xxx` we need to:
1. Get the original source code
2. Parse it into an AST Tree
3. Add this import to `self.transformers_imports` as visited to not parse it twice
@ -338,7 +340,7 @@ class DiffConverterTransformer(CSTTransformer):
If they are from `transformers.models.xx` then:
- take the AST tree of the module it comes from and parse it with a `ClassFinder`.
- rename all every instance of `old_name` (llama) to `new_name` (gemma)
2. We insert the modules which the inherited base depends on. This has to be done in
2. We insert the modules which the inherited base depends on. This has to be done in
the order of the dependencies. If on is already in the new_body (because it's defined in the diff file)
then we remove it from the new body to add it again in the correct order.
3. Replace the calls to `super().xxxx` merging parent code
@ -358,8 +360,10 @@ class DiffConverterTransformer(CSTTransformer):
if super_file_name not in self.visited_module: # only extract classes once
class_finder = find_classes_in_file(self.transformers_imports[super_file_name], old_name, new_name)
self.visited_module[super_file_name] = class_finder
self.node_mapping[class_name] = class_finder.classes[class_name] # here we get the new node form the parent class
else: # we are re-using the previously parsed data
self.node_mapping[class_name] = class_finder.classes[
class_name
] # here we get the new node form the parent class
else: # we are re-using the previously parsed data
class_finder = self.visited_module[super_file_name]
list_dependencies = {
@ -431,10 +435,11 @@ def convert_file(diff_file, cst_transformers=None):
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
ruffed_code = fix_ruff(config_module.code)
f.write(ruffed_code)
# TODO optimize by re-using the class_finder
return cst_transformers
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(