nits
This commit is contained in:
parent
494e6bac14
commit
6c486574bc
|
@ -15,7 +15,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
|
|
|
@ -812,7 +812,6 @@ _CONFIG_FOR_DOC = "GemmaConfig"
|
|||
GEMMA_START_DOCSTRING,
|
||||
)
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
@ -1032,7 +1031,6 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
|
||||
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
|
|
|
@ -336,7 +336,6 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
imported_class = self.python_module.code_for_node(imported_.name)
|
||||
self.imported_mapping[imported_class] = import_statement
|
||||
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
|
||||
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
assign = self.python_module.code_for_node(original_node.body[0])
|
||||
|
@ -352,8 +351,8 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
return cst.RemoveFromParent()
|
||||
self.all_imports.append(updated_node)
|
||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||
self.all_imports.append(updated_node)
|
||||
|
||||
self.all_imports.append(updated_node)
|
||||
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.new_body.append(updated_node)
|
||||
|
@ -435,9 +434,9 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
new_body = []
|
||||
for visiter in self.visited_module.values():
|
||||
new_body += list(visiter.imports.values())
|
||||
# TODO for the config we need to sort the dependencies using `class_finder.`
|
||||
if hasattr(self,"config_body"):
|
||||
self.config_body = self.all_imports + self.config_body
|
||||
# TODO for the config we need to sort the dependencies using `class_finder.`
|
||||
if hasattr(self, "config_body"):
|
||||
self.config_body = self.all_imports + self.config_body
|
||||
return node.with_changes(body=[*new_body, *self.new_body])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue