From 494e6bac147e511b7cc7710732434443fe919424 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 28 May 2024 10:45:04 +0200 Subject: [PATCH] update --- .../models/gemma/configuration_gemma.py | 12 +----------- src/transformers/models/gemma/modeling_gemma.py | 2 ++ utils/diff_model_converter.py | 13 +++++++------ 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 08a581faeb..4f5c772600 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -15,18 +15,8 @@ # limitations under the License. -from ...utils import ( - is_flash_attn_2_available, -) -from .configuration_gemma import GemmaConfig - -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from transformers import PretrainedConfig class GemmaConfig(PretrainedConfig): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 657e66d0e3..40fcd93ff1 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -812,6 +812,7 @@ _CONFIG_FOR_DOC = "GemmaConfig" GEMMA_START_DOCSTRING, ) class GemmaModel(GemmaPreTrainedModel): + def __init__(self, config: GemmaConfig): super().__init__(config) self.padding_idx = config.pad_token_id @@ -1031,6 +1032,7 @@ class GemmaModel(GemmaPreTrainedModel): class GemmaForCausalLM(GemmaPreTrainedModel): + def __init__(self, config): super().__init__(config) self.model = GemmaModel(config) diff --git a/utils/diff_model_converter.py b/utils/diff_model_converter.py index 5783ff08fb..1b0b5e730b 100644 --- a/utils/diff_model_converter.py +++ b/utils/diff_model_converter.py @@ -335,10 +335,7 @@ class DiffConverterTransformer(CSTTransformer): self.transformers_imports[import_statement] = tree imported_class = self.python_module.code_for_node(imported_.name) self.imported_mapping[imported_class] = import_statement - self.all_imports.append(node) - def visit_Import(self, node): - self.all_imports.append(node) def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode): if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): @@ -353,7 +350,10 @@ class DiffConverterTransformer(CSTTransformer): full_statement = self.python_module.code_for_node(updated_node.body[0].module) if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement): return cst.RemoveFromParent() - + self.all_imports.append(updated_node) + if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])): + self.all_imports.append(updated_node) + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): self.new_body.append(updated_node) @@ -435,8 +435,9 @@ class DiffConverterTransformer(CSTTransformer): new_body = [] for visiter in self.visited_module.values(): new_body += list(visiter.imports.values()) - # TODO for the config we need to sort the dependencies using `class_finder.` - self.config_body = list(visiter.imports.values()) + self.config_body + # TODO for the config we need to sort the dependencies using `class_finder.` + if hasattr(self,"config_body"): + self.config_body = self.all_imports + self.config_body return node.with_changes(body=[*new_body, *self.new_body])