fixup
This commit is contained in:
parent
d6ef9e81e5
commit
1ce5c1b5c7
|
@ -15,7 +15,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
from ...utils import (
|
||||
is_flash_attn_2_available,
|
||||
)
|
||||
|
|
|
@ -812,7 +812,6 @@ _CONFIG_FOR_DOC = "GemmaConfig"
|
|||
GEMMA_START_DOCSTRING,
|
||||
)
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
@ -1032,7 +1031,6 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
|
||||
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
|
|
|
@ -11,8 +11,10 @@ from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvide
|
|||
|
||||
from transformers import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_module_source_from_name(module_name: str) -> str:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None or spec.origin is None:
|
||||
|
@ -282,12 +284,12 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
def __init__(self, python_module):
|
||||
super().__init__()
|
||||
self.python_module = python_module # we store the original module to use `code_for_node`
|
||||
self.transformers_imports = {} # all the imports made from transformers.models.xxx
|
||||
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
|
||||
self.node_mapping = {} # stores the name of the nodes that were added to the `new_body`
|
||||
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
|
||||
self.new_body = [] # store the new body, all global scope nodes should be added here
|
||||
self.inserted_deps = [] # nodes inserted via super dependency
|
||||
self.transformers_imports = {} # all the imports made from transformers.models.xxx
|
||||
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
|
||||
self.node_mapping = {} # stores the name of the nodes that were added to the `new_body`
|
||||
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
|
||||
self.new_body = [] # store the new body, all global scope nodes should be added here
|
||||
self.inserted_deps = [] # nodes inserted via super dependency
|
||||
|
||||
def leave_FunctionDef(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
|
@ -296,8 +298,8 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
return node
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
""" When visiting imports from `transformers.models.xxx` we need to:
|
||||
1. Get the original source code
|
||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||
1. Get the original source code
|
||||
2. Parse it into an AST Tree
|
||||
3. Add this import to `self.transformers_imports` as visited to not parse it twice
|
||||
|
||||
|
@ -338,7 +340,7 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
If they are from `transformers.models.xx` then:
|
||||
- take the AST tree of the module it comes from and parse it with a `ClassFinder`.
|
||||
- rename every instance of `old_name` (llama) to `new_name` (gemma)
|
||||
2. We insert the modules which the inherited base depends on. This has to be done in
|
||||
2. We insert the modules which the inherited base depends on. This has to be done in
|
||||
the order of the dependencies. If one is already in the new_body (because it's defined in the diff file)
|
||||
then we remove it from the new body to add it again in the correct order.
|
||||
3. Replace the calls to `super().xxxx` merging parent code
|
||||
|
@ -358,8 +360,10 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
if super_file_name not in self.visited_module: # only extract classes once
|
||||
class_finder = find_classes_in_file(self.transformers_imports[super_file_name], old_name, new_name)
|
||||
self.visited_module[super_file_name] = class_finder
|
||||
self.node_mapping[class_name] = class_finder.classes[class_name] # here we get the new node form the parent class
|
||||
else: # we are re-using the previously parsed data
|
||||
self.node_mapping[class_name] = class_finder.classes[
|
||||
class_name
|
||||
] # here we get the new node form the parent class
|
||||
else: # we are re-using the previously parsed data
|
||||
class_finder = self.visited_module[super_file_name]
|
||||
|
||||
list_dependencies = {
|
||||
|
@ -431,10 +435,11 @@ def convert_file(diff_file, cst_transformers=None):
|
|||
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
|
||||
ruffed_code = fix_ruff(config_module.code)
|
||||
f.write(ruffed_code)
|
||||
|
||||
|
||||
# TODO optimize by re-using the class_finder
|
||||
return cst_transformers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
|
Loading…
Reference in New Issue