update
This commit is contained in:
parent
54af8877cb
commit
494e6bac14
|
@ -15,18 +15,8 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
from ...utils import (
|
||||
is_flash_attn_2_available,
|
||||
)
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
|
|
|
@ -812,6 +812,7 @@ _CONFIG_FOR_DOC = "GemmaConfig"
|
|||
GEMMA_START_DOCSTRING,
|
||||
)
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
@ -1031,6 +1032,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
|
||||
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
|
|
|
@ -335,10 +335,7 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
self.transformers_imports[import_statement] = tree
|
||||
imported_class = self.python_module.code_for_node(imported_.name)
|
||||
self.imported_mapping[imported_class] = import_statement
|
||||
self.all_imports.append(node)
|
||||
|
||||
def visit_Import(self, node):
|
||||
self.all_imports.append(node)
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
|
||||
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
|
@ -353,7 +350,10 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
|
||||
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
|
||||
return cst.RemoveFromParent()
|
||||
|
||||
self.all_imports.append(updated_node)
|
||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||
self.all_imports.append(updated_node)
|
||||
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.new_body.append(updated_node)
|
||||
|
@ -435,8 +435,9 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
new_body = []
|
||||
for visiter in self.visited_module.values():
|
||||
new_body += list(visiter.imports.values())
|
||||
# TODO for the config we need to sort the dependencies using `class_finder.`
|
||||
self.config_body = list(visiter.imports.values()) + self.config_body
|
||||
# TODO for the config we need to sort the dependencies using `class_finder.`
|
||||
if hasattr(self,"config_body"):
|
||||
self.config_body = self.all_imports + self.config_body
|
||||
return node.with_changes(body=[*new_body, *self.new_body])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue