fixup
This commit is contained in:
parent
d6ef9e81e5
commit
1ce5c1b5c7
|
@ -15,7 +15,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
from ...utils import (
|
||||
is_flash_attn_2_available,
|
||||
)
|
||||
|
|
|
@ -812,7 +812,6 @@ _CONFIG_FOR_DOC = "GemmaConfig"
|
|||
GEMMA_START_DOCSTRING,
|
||||
)
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
@ -1032,7 +1031,6 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
|
||||
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
|
|
|
@ -11,8 +11,10 @@ from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvide
|
|||
|
||||
from transformers import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_module_source_from_name(module_name: str) -> str:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None or spec.origin is None:
|
||||
|
@ -296,7 +298,7 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
return node
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
""" When visiting imports from `transformers.models.xxx` we need to:
|
||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||
1. Get the original source code
|
||||
2. Parse it into an AST Tree
|
||||
3. Add this import to `self.transformers_imports` as visited to not parse it twice
|
||||
|
@ -358,7 +360,9 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
if super_file_name not in self.visited_module: # only extract classes once
|
||||
class_finder = find_classes_in_file(self.transformers_imports[super_file_name], old_name, new_name)
|
||||
self.visited_module[super_file_name] = class_finder
|
||||
self.node_mapping[class_name] = class_finder.classes[class_name] # here we get the new node from the parent class
|
||||
self.node_mapping[class_name] = class_finder.classes[
|
||||
class_name
|
||||
] # here we get the new node from the parent class
|
||||
else: # we are re-using the previously parsed data
|
||||
class_finder = self.visited_module[super_file_name]
|
||||
|
||||
|
@ -435,6 +439,7 @@ def convert_file(diff_file, cst_transformers=None):
|
|||
# TODO optimize by re-using the class_finder
|
||||
return cst_transformers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
|
Loading…
Reference in New Issue