This commit is contained in:
Arthur Zucker 2024-05-28 09:38:22 +02:00
parent d6ef9e81e5
commit 1ce5c1b5c7
3 changed files with 17 additions and 15 deletions

View File

@@ -15,7 +15,6 @@
# limitations under the License.
from ...utils import (
is_flash_attn_2_available,
)

View File

@@ -812,7 +812,6 @@ _CONFIG_FOR_DOC = "GemmaConfig"
GEMMA_START_DOCSTRING,
)
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
@@ -1032,7 +1031,6 @@ class GemmaModel(GemmaPreTrainedModel):
class GemmaForCausalLM(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)

View File

@@ -11,8 +11,10 @@ from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvide
from transformers import logging
logger = logging.get_logger(__name__)
def get_module_source_from_name(module_name: str) -> str:
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
@@ -296,7 +298,7 @@ class DiffConverterTransformer(CSTTransformer):
return node
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
""" When visiting imports from `transformers.models.xxx` we need to:
"""When visiting imports from `transformers.models.xxx` we need to:
1. Get the original source code
2. Parse it into an AST Tree
3. Add this import to `self.transformers_imports` as visited to not parse it twice
@@ -358,7 +360,9 @@ class DiffConverterTransformer(CSTTransformer):
if super_file_name not in self.visited_module: # only extract classes once
class_finder = find_classes_in_file(self.transformers_imports[super_file_name], old_name, new_name)
self.visited_module[super_file_name] = class_finder
self.node_mapping[class_name] = class_finder.classes[class_name] # here we get the new node form the parent class
self.node_mapping[class_name] = class_finder.classes[
class_name
] # here we get the new node form the parent class
else: # we are re-using the previously parsed data
class_finder = self.visited_module[super_file_name]
@@ -435,6 +439,7 @@ def convert_file(diff_file, cst_transformers=None):
# TODO optimize by re-using the class_finder
return cst_transformers
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(