This commit is contained in:
Arthur Zucker 2024-05-30 16:05:11 +02:00
parent 8a85473357
commit 513b933b60
3 changed files with 35 additions and 41 deletions

View File

@ -17,8 +17,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

View File

@ -23,15 +23,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -50,14 +47,16 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
"""PyTorch LLaMA model."""
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)

View File

@ -103,25 +103,24 @@ class ClassFinder(CSTVisitor):
def leave_If(self, node):
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
# TODO this can actually be a duplicated import?
self.imports[stmt.body[0].names] = node # match the visit simple statement line to overwrite it
self.imports[stmt.body[0].names] = node
def leave_Name(self, node):
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
dad = self.get_metadata(cst.metadata.ScopeProvider, node)
if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
self._update_class_dependency(dad._name_prefix.split(".")[0], node.value)
parent = self.get_metadata(cst.metadata.ScopeProvider, node)
if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
self._update_class_dependency(parent._name_prefix.split(".")[0], node.value)
def leave_Arg(self, node):
if m.matches(node.value, m.Name()):
dad = self.get_metadata(ParentNodeProvider, node)
if m.matches(dad, m.ClassDef()) and dad.bases:
self._update_class_dependency(dad.name.value, node.value.value)
parent = self.get_metadata(ParentNodeProvider, node)
if m.matches(parent, m.ClassDef()) and parent.bases:
self._update_class_dependency(parent.name.value, node.value.value)
def leave_Dict(self, node):
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(dad, m.Assign(targets=[m.AssignTarget()])):
name = dad.targets[0].target.value
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent, m.Assign(targets=[m.AssignTarget()])):
name = parent.targets[0].target.value
if name in self.assignments:
for k in node.elements:
dep_name = k.value.value
@ -132,9 +131,9 @@ class ClassFinder(CSTVisitor):
if hasattr(node.decorator, "args"):
for k in node.decorator.args:
if k.value.value in self.assignments:
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else dad.name.value
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
self._update_class_dependency(name, k.value.value)
def leave_Module(self, node):
@ -304,11 +303,9 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[
0
] # get the indented block
return original_node.with_changes(body=new_replacement_class.body)
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
new_replacement_body = new_replacement_class.body[0].body # get the indented block
return original_node.with_changes(body=new_replacement_body)
class DiffConverterTransformer(CSTTransformer):
@ -327,13 +324,6 @@ class DiffConverterTransformer(CSTTransformer):
self.global_scope_index = 0
# fmt: on
def leave_FunctionDef(self, original_node, node):
    """Register a module-level function definition in ``self.new_body``.

    The parent of ``original_node`` is looked up through
    ``ParentNodeProvider``. When that parent is the module itself,
    ``self.global_scope_index`` is advanced by 100 and ``node`` is stored
    under the function's name together with that index as ``insert_idx``.

    Returns:
        ``node`` unchanged, so the tree itself is not modified.
    """
    enclosing = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
    if not m.matches(enclosing, m.Module()):
        # Functions whose parent is not the module are not recorded here.
        return node

    self.global_scope_index += 100
    entry = {"insert_idx": self.global_scope_index, "node": node}
    self.new_body[node.name.value] = entry
    return node
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from `transformers.models.xxx` we need to:
1. Get the original source code
@ -351,6 +341,13 @@ class DiffConverterTransformer(CSTTransformer):
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
def leave_FunctionDef(self, original_node, node):
    """Register a module-level function definition in ``self.new_body``.

    The parent of ``original_node`` is looked up through
    ``ParentNodeProvider``. When that parent is the module itself,
    ``self.global_scope_index`` is advanced by 100 and ``node`` is stored
    under the function's name together with that index as ``insert_idx``.

    Returns:
        ``node`` unchanged, so the tree itself is not modified.
    """
    enclosing = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
    if not m.matches(enclosing, m.Module()):
        # Functions whose parent is not the module are not recorded here.
        return node

    self.global_scope_index += 100
    entry = {"insert_idx": self.global_scope_index, "node": node}
    self.new_body[node.name.value] = entry
    return node
def leave_SimpleStatementLine(self, original_node, updated_node):
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
self.all_imports.append(updated_node)
@ -423,15 +420,15 @@ class DiffConverterTransformer(CSTTransformer):
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node
# def leave_If(self, original_node, node):
# parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
# if m.matches(parent_node, m.Module()):
# full_statement = self.python_module.code_for_node(original_node.test)
# if re.search(r"[\s\S]*is_.*available", full_statement):
# self.all_imports.append(node)
# else:
# self.new_body[node] = {"insert_idx":self.global_scope_index, "node":node}
# return node
def leave_If(self, original_node, node):
    """Sort a module-level ``if`` statement into the imports or the body.

    Only ``if`` statements whose parent (per ``ParentNodeProvider``) is the
    module are considered. The source text of the test expression is
    searched with the ``is_.*available`` pattern:

    * on a match, ``node`` is appended to ``self.all_imports``;
    * otherwise ``node`` is stored in ``self.new_body`` with the current
      ``self.global_scope_index`` as ``insert_idx`` (the index is not
      advanced here).

    Returns:
        ``node`` unchanged.
    """
    enclosing = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
    if not m.matches(enclosing, m.Module()):
        return node

    test_source = self.python_module.code_for_node(original_node.test)
    if re.search(r"[\s\S]*is_.*available", test_source):
        self.all_imports.append(node)
        return node

    # NOTE(review): the membership test uses the test's source string while
    # the entry below is keyed by the node object itself, so the two never
    # refer to the same key -- confirm which key is intended.
    if test_source not in self.new_body:
        self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
    return node
def leave_Module(self, original_node: cst.Assign, node):
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}