nits
This commit is contained in:
parent
8a85473357
commit
513b933b60
|
@ -17,8 +17,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch LLaMA model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
|
|
@ -23,15 +23,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
|
@ -50,14 +47,16 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
|||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
"""PyTorch LLaMA model."""
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
|
@ -103,25 +103,24 @@ class ClassFinder(CSTVisitor):
|
|||
def leave_If(self, node):
|
||||
for stmt in node.body.body:
|
||||
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
|
||||
# TODO this can actually be a duplicated import?
|
||||
self.imports[stmt.body[0].names] = node # match the visit simple statement line to overwrite it
|
||||
self.imports[stmt.body[0].names] = node
|
||||
|
||||
def leave_Name(self, node):
|
||||
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
|
||||
dad = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||
if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
|
||||
self._update_class_dependency(dad._name_prefix.split(".")[0], node.value)
|
||||
parent = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||
if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
|
||||
self._update_class_dependency(parent._name_prefix.split(".")[0], node.value)
|
||||
|
||||
def leave_Arg(self, node):
|
||||
if m.matches(node.value, m.Name()):
|
||||
dad = self.get_metadata(ParentNodeProvider, node)
|
||||
if m.matches(dad, m.ClassDef()) and dad.bases:
|
||||
self._update_class_dependency(dad.name.value, node.value.value)
|
||||
parent = self.get_metadata(ParentNodeProvider, node)
|
||||
if m.matches(parent, m.ClassDef()) and parent.bases:
|
||||
self._update_class_dependency(parent.name.value, node.value.value)
|
||||
|
||||
def leave_Dict(self, node):
|
||||
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
if m.matches(dad, m.Assign(targets=[m.AssignTarget()])):
|
||||
name = dad.targets[0].target.value
|
||||
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
if m.matches(parent, m.Assign(targets=[m.AssignTarget()])):
|
||||
name = parent.targets[0].target.value
|
||||
if name in self.assignments:
|
||||
for k in node.elements:
|
||||
dep_name = k.value.value
|
||||
|
@ -132,9 +131,9 @@ class ClassFinder(CSTVisitor):
|
|||
if hasattr(node.decorator, "args"):
|
||||
for k in node.decorator.args:
|
||||
if k.value.value in self.assignments:
|
||||
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else dad.name.value
|
||||
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
|
||||
self._update_class_dependency(name, k.value.value)
|
||||
|
||||
def leave_Module(self, node):
|
||||
|
@ -304,11 +303,9 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
|||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||
temp_module = cst.Module(body=[result_node])
|
||||
new_module = MetadataWrapper(temp_module)
|
||||
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[
|
||||
0
|
||||
] # get the indented block
|
||||
|
||||
return original_node.with_changes(body=new_replacement_class.body)
|
||||
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
|
||||
new_replacement_body = new_replacement_class.body[0].body # get the indented block
|
||||
return original_node.with_changes(body=new_replacement_body)
|
||||
|
||||
|
||||
class DiffConverterTransformer(CSTTransformer):
|
||||
|
@ -327,13 +324,6 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
self.global_scope_index = 0
|
||||
# fmt: on
|
||||
|
||||
def leave_FunctionDef(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.global_scope_index += 100
|
||||
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||
1. Get the original source code
|
||||
|
@ -351,6 +341,13 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
imported_class = self.python_module.code_for_node(imported_.name)
|
||||
self.imported_mapping[imported_class] = import_statement
|
||||
|
||||
def leave_FunctionDef(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
self.global_scope_index += 100
|
||||
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
|
||||
return node
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node, updated_node):
|
||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||
self.all_imports.append(updated_node)
|
||||
|
@ -423,15 +420,15 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
return updated_node
|
||||
|
||||
# def leave_If(self, original_node, node):
|
||||
# parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
# if m.matches(parent_node, m.Module()):
|
||||
# full_statement = self.python_module.code_for_node(original_node.test)
|
||||
# if re.search(r"[\s\S]*is_.*available", full_statement):
|
||||
# self.all_imports.append(node)
|
||||
# else:
|
||||
# self.new_body[node] = {"insert_idx":self.global_scope_index, "node":node}
|
||||
# return node
|
||||
def leave_If(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
full_statement = self.python_module.code_for_node(original_node.test)
|
||||
if re.search(r"[\s\S]*is_.*available", full_statement):
|
||||
self.all_imports.append(node)
|
||||
elif full_statement not in self.new_body:
|
||||
self.new_body[node] = {"insert_idx":self.global_scope_index, "node":node}
|
||||
return node
|
||||
|
||||
def leave_Module(self, original_node: cst.Assign, node):
|
||||
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
|
||||
|
|
Loading…
Reference in New Issue