commit 099041413b (parent 1ce5c1b5c7)
@@ -37,7 +37,7 @@ class ClassFinder(CSTVisitor):
             super().__init__(self)
             self.value = init_value()
     ```
-    then the `class_dependency_mapping` should be: `{"LlamaModel":{"PreTrainedModel":{},"init_value":{}}, "init_value":{}}`
+    then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]}`
     """

     METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
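For readers unfamiliar with libcst visitors, here is a minimal, hypothetical sketch (not the PR's code) of how a `CSTVisitor` can build a dependency mapping in the list format shown in the docstring above. It records only inheritance edges, not call edges like `init_value`:

```python
import libcst as cst

class BaseClassFinder(cst.CSTVisitor):
    """Toy visitor: maps each class name to the names of its base classes."""

    def __init__(self):
        super().__init__()
        self.class_dependency_mapping = {}

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # `node.bases` holds the bases of the class, e.g. `PreTrainedModel`.
        deps = [base.value.value for base in node.bases if isinstance(base.value, cst.Name)]
        self.class_dependency_mapping[node.name.value] = deps

code = """
class LlamaModel(PreTrainedModel):
    def __init__(self):
        super().__init__(self)
        self.value = init_value()
"""
finder = BaseClassFinder()
cst.parse_module(code).visit(finder)
print(finder.class_dependency_mapping)  # {'LlamaModel': ['PreTrainedModel']}
```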
@@ -48,7 +48,6 @@ class ClassFinder(CSTVisitor):
         self.imports = {}                   # from flash_attn import
         self.function_def = {}              # def repeat_kv
         self.assignments = {}               # LLAMA_DOCSTRING
         self.protected_imports = {}         # if is_xxx_available()
         self.class_dependency_mapping = {}  # "LlamaModel":["LlamaDecoderLayer", "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
-
     def visit_ClassDef(self, node: ClassDef) -> None:
@@ -114,7 +113,7 @@ class ClassFinder(CSTVisitor):
                     self.class_dependency_mapping[name] = dep
                     logger.info(f"Dict:\t\t{k.value.value:<45} called in {name}")

-    # Decorator: handle in leave_FunctionDef and leave_ClassDef instead
+    # Decorator: handle nodes used in the decorators
    def leave_Decorator(self, node):
        if hasattr(node.decorator, "args"):
            for k in node.decorator.args:
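The `leave_Decorator` hook above walks decorator arguments. As a hedged illustration (hypothetical names, not the PR's code), the same idea can be expressed with libcst matchers, which also cover bare decorators like `@property` gracefully:

```python
import libcst as cst
import libcst.matchers as m

class DecoratorDepFinder(cst.CSTVisitor):
    """Toy visitor: collects every Name used inside decorator calls."""

    def __init__(self):
        super().__init__()
        self.decorator_names = []

    def leave_Decorator(self, node: cst.Decorator) -> None:
        # A decorator with arguments, e.g. `@add_start_docstrings(DOC)`,
        # parses as a Call node; a bare `@property` does not.
        if m.matches(node.decorator, m.Call()):
            self.decorator_names += [n.value for n in m.findall(node.decorator, m.Name())]

code = '''
@add_start_docstrings(LLAMA_START_DOCSTRING)
class LlamaModel:
    pass
'''
finder = DecoratorDepFinder()
cst.parse_module(code).visit(finder)
print(finder.decorator_names)  # ['add_start_docstrings', 'LLAMA_START_DOCSTRING']
```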
@@ -256,6 +255,26 @@ class SuperTransformer(cst.CSTTransformer):


+def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
+    """
+    Given the `class_name`, the `updated_node`'s calls to super are unpacked.
+
+                    | ```python                     |         | ```python
+                    | class GemmaModel(LlamaModel): |         | class GemmaModel(nn.Module):
+                    |     def __init__(self):       |         |     def __init__(self):
+    Going from:     |         self.dropout = 0.2    |   to:   |         self.dropout = 0.2
+                    |         super().__init__()    |         |         super().__init__(config)
+                    | ```                           |         |         self.padding_idx = config.pad_token_id
+                                                              |         self.vocab_size = config.vocab_size
+                                                              |         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+                                                              |         self.layers = nn.ModuleList(
+                                                              |             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+                                                              |         )
+                                                              |         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+                                                              |         self.gradient_checkpointing = False
+                                                              |         # Initialize weights and apply final processing
+                                                              |         self.post_init()
+                                                              | ```
+    """
+    original_node = class_finder.classes[class_name]
+    original_methods = {f.name.value: f for f in original_node.body.body if m.matches(f, m.FunctionDef())}
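A hedged sketch of the core pattern-match such an unpacking needs: detecting the `super().__init__()` call to replace. The matcher shape below is mine, not necessarily the PR's:

```python
import libcst as cst
import libcst.matchers as m

# Shape of `super().__init__(...)`: a Call whose func is the `__init__`
# attribute of a `super()` call. Arguments are deliberately left open.
SUPER_INIT_CALL = m.Call(
    func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name("__init__"))
)

print(m.matches(cst.parse_expression("super().__init__()"), SUPER_INIT_CALL))        # True
print(m.matches(cst.parse_expression("super().__init__(config)"), SUPER_INIT_CALL))  # True
print(m.matches(cst.parse_expression("self.post_init()"), SUPER_INIT_CALL))          # False
```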
@@ -283,13 +302,16 @@ class DiffConverterTransformer(CSTTransformer):

     def __init__(self, python_module):
         super().__init__()
         # fmt: off
         self.python_module = python_module  # we store the original module to use `code_for_node`
-        self.transformers_imports = {}      # all the imports made from transformers.models.xxx
-        self.imported_mapping = {}          # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
-        self.node_mapping = {}              # stores the name of the nodes that were added to the `new_body`
-        self.visited_module = {}            # modules visited like "transformers.models.llama.modeling_llama"
-        self.new_body = []                  # store the new body; all global-scope nodes should be added here
-        self.inserted_deps = []             # nodes inserted via super dependency
+        self.transformers_imports = {}      # maps an import name like "from transformers.models.xxx" to the parsed AST module
+        self.imported_mapping = {}          # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
+        self.node_mapping = {}              # stores the name of the nodes that were added to the `new_body`
+        self.visited_module = {}            # modules visited like "transformers.models.llama.modeling_llama"
+        self.new_body = []                  # store the new body; all global-scope nodes should be added here
+        self.inserted_deps = []             # nodes inserted via super dependency
+        self.all_imports = []               # just stores all of the imports
         # fmt: on

     def leave_FunctionDef(self, original_node, node):
         parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
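The comment on `self.python_module` refers to libcst's `Module.code_for_node`, which renders any node of the tree back to source text; that is why the transformer keeps the original module around. A minimal sketch:

```python
import libcst as cst

module = cst.parse_module("DUMMY_DOCSTRING = 'hello'\nclass Foo:\n    pass\n")
first_statement = module.body[0]
# `code_for_node` needs the owning Module, so the node alone is not enough.
print(module.code_for_node(first_statement))  # DUMMY_DOCSTRING = 'hello'
```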
@@ -302,10 +324,9 @@ class DiffConverterTransformer(CSTTransformer):
         1. Get the original source code
         2. Parse it into an AST Tree
         3. Add this import to `self.transformers_imports` as visited, so it is not parsed twice
-
         """
+        import_statement = self.python_module.code_for_node(node.module)
         if m.matches(node.module, m.Attribute()):
-            import_statement = self.python_module.code_for_node(node.module)
             for imported_ in node.names:
                 if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement):
                     if import_statement not in self.transformers_imports:
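Steps 1-3 amount to a parse-once cache keyed by the import path. A hedged sketch of the idea (the helper and cache names are hypothetical, and running it requires `transformers` to be installed):

```python
import importlib
import inspect
import libcst as cst

_parsed_modules = {}  # hypothetical stand-in for `self.transformers_imports`

def get_parsed_module(import_path: str) -> cst.Module:
    """Resolve a dotted module path, read its source, parse it once, cache it."""
    if import_path not in _parsed_modules:
        source = inspect.getsource(importlib.import_module(import_path))
        _parsed_modules[import_path] = cst.parse_module(source)
    return _parsed_modules[import_path]

# Repeated lookups reuse the cached tree instead of re-reading and re-parsing.
tree = get_parsed_module("transformers.models.llama.modeling_llama")
```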
@@ -314,6 +335,10 @@ class DiffConverterTransformer(CSTTransformer):
                         self.transformers_imports[import_statement] = tree
                     imported_class = self.python_module.code_for_node(imported_.name)
                     self.imported_mapping[imported_class] = import_statement
+        self.all_imports.append(node)
+
+    def visit_Import(self, node):
+        self.all_imports.append(node)

     def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
         if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
@@ -382,7 +407,6 @@ class DiffConverterTransformer(CSTTransformer):
                 elif dependency not in self.inserted_deps:
                     # make sure the node is written after its dependencies
                     node = self.node_mapping[dependency]
-
                     self.new_body.remove(node)
                     self.new_body.append(node)
                     self.inserted_deps.append(dependency)
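The remove/append pair implements a small ordering fix-up: a dependency that already sits in `new_body` is pulled to the end just before its dependent is appended, so the dependent ends up after it. A toy sketch with strings in place of CST nodes:

```python
# Toy model of the reordering above (strings instead of CST nodes).
new_body = ["LlamaRMSNorm", "LlamaPreTrainedModel"]
inserted_deps = []

dependency = "LlamaRMSNorm"
node = dependency  # stands in for self.node_mapping[dependency]
if dependency not in inserted_deps:
    new_body.remove(node)         # drop the earlier occurrence...
    new_body.append(node)         # ...and re-emit it at the end
    inserted_deps.append(dependency)

new_body.append("LlamaModel")     # the dependent class now follows its dependency
print(new_body)  # ['LlamaPreTrainedModel', 'LlamaRMSNorm', 'LlamaModel']
```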