This commit is contained in:
Arthur Zucker 2024-05-28 10:50:26 +02:00
parent 494e6bac14
commit 6c486574bc
3 changed files with 5 additions and 9 deletions

View File

@ -15,7 +15,6 @@
# limitations under the License.
from transformers import PretrainedConfig

View File

@ -812,7 +812,6 @@ _CONFIG_FOR_DOC = "GemmaConfig"
GEMMA_START_DOCSTRING,
)
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
@ -1032,7 +1031,6 @@ class GemmaModel(GemmaPreTrainedModel):
class GemmaForCausalLM(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)

View File

@ -336,7 +336,6 @@ class DiffConverterTransformer(CSTTransformer):
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
assign = self.python_module.code_for_node(original_node.body[0])
@ -352,8 +351,8 @@ class DiffConverterTransformer(CSTTransformer):
return cst.RemoveFromParent()
self.all_imports.append(updated_node)
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
self.all_imports.append(updated_node)
self.all_imports.append(updated_node)
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.new_body.append(updated_node)
@ -435,9 +434,9 @@ class DiffConverterTransformer(CSTTransformer):
new_body = []
for visiter in self.visited_module.values():
new_body += list(visiter.imports.values())
# TODO for the config we need to sort the dependencies using `class_finder.`
if hasattr(self,"config_body"):
self.config_body = self.all_imports + self.config_body
# TODO for the config we need to sort the dependencies using `class_finder.`
if hasattr(self, "config_body"):
self.config_body = self.all_imports + self.config_body
return node.with_changes(body=[*new_body, *self.new_body])