Arthur Zucker 2024-05-28 10:32:50 +02:00
parent 1ce5c1b5c7
commit 099041413b
1 changed file with 36 additions and 12 deletions


@@ -37,7 +37,7 @@ class ClassFinder(CSTVisitor):
super().__init__(self)
self.value = init_value()
```
then the `class_dependency_mapping` should be: `{"LlamaModel":{"PreTrainedModel":{},"init_value":{}}, "init_value":{}}
then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]}
"""
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
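For reference, a minimal, hypothetical sketch (not this file's helper) of how a flat `{class_name: [dependencies]}` mapping like the one in the docstring above can be built with libcst. The real `ClassFinder` also records names used inside the class body (such as `init_value`), which this sketch skips.

```python
import libcst as cst
import libcst.matchers as m

code = """
class LlamaModel(PreTrainedModel):
    def __init__(self):
        super().__init__(self)
        self.value = init_value()
"""

# Record only the base classes of every ClassDef as its dependencies.
class_dependency_mapping = {}
for node in m.findall(cst.parse_module(code), m.ClassDef()):
    bases = [arg.value.value for arg in node.bases if m.matches(arg.value, m.Name())]
    class_dependency_mapping[node.name.value] = bases

print(class_dependency_mapping)  # {'LlamaModel': ['PreTrainedModel']}
```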
@@ -48,7 +48,6 @@ class ClassFinder(CSTVisitor):
self.imports = {} # from flash_attn import
self.function_def = {} # def repeat_kv
self.assignments = {} # LLAMA_DOCSTRING
self.protected_imports = {} # if is_xxx_available()
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer", "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
def visit_ClassDef(self, node: ClassDef) -> None:
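A hedged, stripped-down illustration of the collection step that fills dicts like the ones initialised above (class and variable names here are invented; the real visitor additionally tracks imports, protected imports and dependencies).

```python
import libcst as cst

class TinyFinder(cst.CSTVisitor):
    """Collects functions, assignments and classes of a module by name."""

    def __init__(self):
        super().__init__()
        self.classes = {}        # "LlamaModel"      -> ClassDef node
        self.function_def = {}   # "repeat_kv"       -> FunctionDef node
        self.assignments = {}    # "LLAMA_DOCSTRING" -> Assign node

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.classes[node.name.value] = node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.function_def[node.name.value] = node

    def visit_Assign(self, node: cst.Assign) -> None:
        for target in node.targets:
            if isinstance(target.target, cst.Name):
                self.assignments[target.target.value] = node

finder = TinyFinder()
cst.parse_module("LLAMA_DOCSTRING = 'doc'\ndef repeat_kv(x): return x").visit(finder)
print(sorted(finder.assignments), sorted(finder.function_def))  # ['LLAMA_DOCSTRING'] ['repeat_kv']
```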
@@ -114,7 +113,7 @@
self.class_dependency_mapping[name] = dep
logger.info(f"Dict:\t\t{k.value.value:<45} called in {name}")
# Decorator: handle in leave_FunctionDef and leave_ClassDef instead
# Decorator: handle nodes used in the decorators
def leave_Decorator(self, node):
if hasattr(node.decorator, "args"):
for k in node.decorator.args:
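A hypothetical illustration of the decorator handling above: any name appearing in a decorator, its arguments included, is treated as something the decorated class depends on. The decorator and constant names below are made up for the example.

```python
import libcst as cst
import libcst.matchers as m

code = """
@add_start_docstrings(LLAMA_START_DOCSTRING)
class LlamaModel(LlamaPreTrainedModel):
    pass
"""

decorator = m.findall(cst.parse_module(code), m.Decorator())[0]
# findall picks up every Name inside the decorator, i.e. the decorating
# function itself and the constant passed as an argument.
deps = [name.value for name in m.findall(decorator, m.Name())]
print(deps)  # ['add_start_docstrings', 'LLAMA_START_DOCSTRING']
```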
@@ -256,6 +255,26 @@ class SuperTransformer(cst.CSTTransformer):
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
"""
Given the `class_name`, the `updated_node`'s calls to super are unpacked.
                 | ```python                     |         | ```python
                 | class GemmaModel(LlamaModel): |         | class GemmaModel(nn.Module):
                 |     def __init__(self):       |         |     def __init__(self):
    Going from:  |         self.dropout = 0.2    |   to:   |         self.dropout = 0.2
                 |         super().__init__()    |         |         super().__init__(config)
                 | ```                           |         |         self.padding_idx = config.pad_token_id
                                                           |         self.vocab_size = config.vocab_size
                                                           |         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
                                                           |         self.layers = nn.ModuleList(
                                                           |             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
                                                           |         )
                                                           |         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                                                           |         self.gradient_checkpointing = False
                                                           |         # Initialize weights and apply final processing
                                                           |         self.post_init()
                                                           | ```
"""
original_node = class_finder.classes[class_name]
original_methods = {f.name.value: f for f in original_node.body.body if m.matches(f, m.FunctionDef())}
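A simplified sketch of the unrolling pictured in the docstring, under strong assumptions (a single bare `super().__init__()` statement, no argument remapping or de-duplication, which the real `SuperTransformer` does handle): the parent's `__init__` body is spliced in place of the call.

```python
import libcst as cst
import libcst.matchers as m

parent = cst.parse_module("""
class LlamaModel:
    def __init__(self):
        self.padding_idx = config.pad_token_id
        self.post_init()
""")
child = cst.parse_module("""
class GemmaModel(LlamaModel):
    def __init__(self):
        self.dropout = 0.2
        super().__init__()
""")

parent_init = m.findall(parent, m.FunctionDef(name=m.Name("__init__")))[0]

class UnrollSuper(cst.CSTTransformer):
    def leave_SimpleStatementLine(self, original_node, updated_node):
        # Swap the `super().__init__()` line for the parent's __init__ body.
        if m.findall(original_node, m.Call(func=m.Attribute(attr=m.Name("__init__")))):
            return cst.FlattenSentinel(parent_init.body.body)
        return updated_node

print(child.visit(UnrollSuper()).code)
```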
@@ -283,13 +302,16 @@ class DiffConverterTransformer(CSTTransformer):
def __init__(self, python_module):
super().__init__()
# fmt: off
self.python_module = python_module # we store the original module to use `code_for_node`
self.transformers_imports = {} # all the imports made from transformers.models.xxx
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.node_mapping = {} # stores the name of the nodes that were added to the `new_body`
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.new_body = [] # store the new body, all global scope nodes should be added here
self.inserted_deps = [] # nodes inserted via super dependency
self.transformers_imports = {} # maps an import statement like "from transformers.models.xxx" to its parsed AST module
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.node_mapping = {} # stores the name of the nodes that were added to the `new_body`
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.new_body = [] # store the new body, all global scope nodes should be added here
self.inserted_deps = [] # nodes inserted via super dependency
self.all_imports = [] # just stores all of the imports
# fmt: on
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
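A hedged sketch of the `ParentNodeProvider` lookup used just above: asking for a node's parent makes it easy to keep only module-level function definitions (the class and function names below are invented).

```python
import libcst as cst
from libcst.metadata import MetadataWrapper, ParentNodeProvider

class TopLevelDefs(cst.CSTTransformer):
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self):
        super().__init__()
        self.top_level = []

    def leave_FunctionDef(self, original_node, updated_node):
        # A function whose parent is the Module itself is defined at global scope.
        if isinstance(self.get_metadata(ParentNodeProvider, original_node), cst.Module):
            self.top_level.append(original_node.name.value)
        return updated_node

source = "def repeat_kv(x): return x\nclass A:\n    def forward(self): pass\n"
collector = TopLevelDefs()
MetadataWrapper(cst.parse_module(source)).visit(collector)
print(collector.top_level)  # ['repeat_kv']
```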
@@ -302,10 +324,9 @@ class DiffConverterTransformer(CSTTransformer):
1. Get the original source code
2. Parse it into an AST Tree
3. Add this import to `self.transformers_imports` as visited, so it is not parsed twice
"""
import_statement = self.python_module.code_for_node(node.module)
if m.matches(node.module, m.Attribute()):
import_statement = self.python_module.code_for_node(node.module)
for imported_ in node.names:
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement):
if import_statement not in self.transformers_imports:
@@ -314,6 +335,10 @@ class DiffConverterTransformer(CSTTransformer):
self.transformers_imports[import_statement] = tree
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
self.all_imports.append(node)
def visit_Import(self, node):
self.all_imports.append(node)
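A hypothetical sketch (not this file's helper) of the resolve-and-cache step described in the docstring above: map a `transformers.models.*.modeling_*` module path to its source file, parse it once with libcst, and reuse the cached tree on later imports. `importlib` is used here only to locate the file and assumes `transformers` is installed.

```python
import importlib.util
import re
import libcst as cst

_TREE_CACHE = {}  # plays the role of self.transformers_imports

def parse_transformers_module(import_statement: str):
    """Return the parsed libcst tree for a modeling/configuration module, or None."""
    if not re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement):
        return None
    if import_statement not in _TREE_CACHE:
        spec = importlib.util.find_spec(import_statement)
        with open(spec.origin, "r", encoding="utf-8") as f:
            _TREE_CACHE[import_statement] = cst.parse_module(f.read())
    return _TREE_CACHE[import_statement]

tree = parse_transformers_module("transformers.models.llama.modeling_llama")
```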
def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
if m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
@@ -382,7 +407,6 @@ class DiffConverterTransformer(CSTTransformer):
elif dependency not in self.inserted_deps:
# make sure the node is written after its dependencies
node = self.node_mapping[dependency]
self.new_body.remove(node)
self.new_body.append(node)
self.inserted_deps.append(dependency)
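A toy, plain-Python illustration of the remove-then-append mechanics in this last hunk: removing a node and re-appending it moves it to the end of `new_body` without duplicating it (the names below are invented).

```python
# Invented names; this only demonstrates the list manipulation itself.
new_body = ["LlamaRMSNorm", "LlamaModel", "LlamaDecoderLayer"]
node = "LlamaModel"

new_body.remove(node)  # drop the node from its current position...
new_body.append(node)  # ...and re-emit it at the very end, exactly once
print(new_body)        # ['LlamaRMSNorm', 'LlamaDecoderLayer', 'LlamaModel']
```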