This commit is contained in:
Arthur Zucker 2024-05-28 10:45:04 +02:00
parent 54af8877cb
commit 494e6bac14
3 changed files with 10 additions and 17 deletions

View File

@ -15,18 +15,8 @@
# limitations under the License.
from ...utils import (
is_flash_attn_2_available,
)
from .configuration_gemma import GemmaConfig
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from transformers import PretrainedConfig
class GemmaConfig(PretrainedConfig):

View File

@ -812,6 +812,7 @@ _CONFIG_FOR_DOC = "GemmaConfig"
GEMMA_START_DOCSTRING,
)
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
@ -1031,6 +1032,7 @@ class GemmaModel(GemmaPreTrainedModel):
class GemmaForCausalLM(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)

View File

@ -335,10 +335,7 @@ class DiffConverterTransformer(CSTTransformer):
self.transformers_imports[import_statement] = tree
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
self.all_imports.append(node)
def visit_Import(self, node):
self.all_imports.append(node)
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
@ -353,7 +350,10 @@ class DiffConverterTransformer(CSTTransformer):
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
return cst.RemoveFromParent()
self.all_imports.append(updated_node)
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
self.all_imports.append(updated_node)
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.new_body.append(updated_node)
@ -435,8 +435,9 @@ class DiffConverterTransformer(CSTTransformer):
new_body = []
for visiter in self.visited_module.values():
new_body += list(visiter.imports.values())
# TODO for the config we need to sort the dependencies using `class_finder.`
self.config_body = list(visiter.imports.values()) + self.config_body
# TODO for the config we need to sort the dependencies using `class_finder.`
if hasattr(self,"config_body"):
self.config_body = self.all_imports + self.config_body
return node.with_changes(body=[*new_body, *self.new_body])