nits
This commit is contained in:
parent
494e6bac14
commit
6c486574bc
|
@ -15,7 +15,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -812,7 +812,6 @@ _CONFIG_FOR_DOC = "GemmaConfig"
|
||||||
GEMMA_START_DOCSTRING,
|
GEMMA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class GemmaModel(GemmaPreTrainedModel):
|
class GemmaModel(GemmaPreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config: GemmaConfig):
|
def __init__(self, config: GemmaConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
|
@ -1032,7 +1031,6 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||||
|
|
||||||
|
|
||||||
class GemmaForCausalLM(GemmaPreTrainedModel):
|
class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model = GemmaModel(config)
|
self.model = GemmaModel(config)
|
||||||
|
|
|
@ -336,7 +336,6 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
imported_class = self.python_module.code_for_node(imported_.name)
|
imported_class = self.python_module.code_for_node(imported_.name)
|
||||||
self.imported_mapping[imported_class] = import_statement
|
self.imported_mapping[imported_class] = import_statement
|
||||||
|
|
||||||
|
|
||||||
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
|
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
|
||||||
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
|
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
|
||||||
assign = self.python_module.code_for_node(original_node.body[0])
|
assign = self.python_module.code_for_node(original_node.body[0])
|
||||||
|
@ -352,8 +351,8 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
return cst.RemoveFromParent()
|
return cst.RemoveFromParent()
|
||||||
self.all_imports.append(updated_node)
|
self.all_imports.append(updated_node)
|
||||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||||
self.all_imports.append(updated_node)
|
self.all_imports.append(updated_node)
|
||||||
|
|
||||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||||
if m.matches(parent_node, m.Module()):
|
if m.matches(parent_node, m.Module()):
|
||||||
self.new_body.append(updated_node)
|
self.new_body.append(updated_node)
|
||||||
|
@ -435,9 +434,9 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
new_body = []
|
new_body = []
|
||||||
for visiter in self.visited_module.values():
|
for visiter in self.visited_module.values():
|
||||||
new_body += list(visiter.imports.values())
|
new_body += list(visiter.imports.values())
|
||||||
# TODO for the config we need to sort the dependencies using `class_finder.`
|
# TODO for the config we need to sort the dependencies using `class_finder.`
|
||||||
if hasattr(self,"config_body"):
|
if hasattr(self, "config_body"):
|
||||||
self.config_body = self.all_imports + self.config_body
|
self.config_body = self.all_imports + self.config_body
|
||||||
return node.with_changes(body=[*new_body, *self.new_body])
|
return node.with_changes(body=[*new_body, *self.new_body])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue