nits
This commit is contained in:
parent
8a85473357
commit
513b933b60
|
@ -17,8 +17,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch LLaMA model."""
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
|
@ -23,15 +23,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
|
@ -50,14 +47,16 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_greater_or_equal_2_10,
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .configuration_llama import LlamaConfig
|
from .configuration_llama import LlamaConfig
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
"""PyTorch LLaMA model."""
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
|
@ -103,25 +103,24 @@ class ClassFinder(CSTVisitor):
|
||||||
def leave_If(self, node):
|
def leave_If(self, node):
|
||||||
for stmt in node.body.body:
|
for stmt in node.body.body:
|
||||||
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
|
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
|
||||||
# TODO this can actually be a duplicated import?
|
self.imports[stmt.body[0].names] = node
|
||||||
self.imports[stmt.body[0].names] = node # match the visit simple statement line to overwrite it
|
|
||||||
|
|
||||||
def leave_Name(self, node):
|
def leave_Name(self, node):
|
||||||
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
|
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
|
||||||
dad = self.get_metadata(cst.metadata.ScopeProvider, node)
|
parent = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||||
if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
|
if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
|
||||||
self._update_class_dependency(dad._name_prefix.split(".")[0], node.value)
|
self._update_class_dependency(parent._name_prefix.split(".")[0], node.value)
|
||||||
|
|
||||||
def leave_Arg(self, node):
|
def leave_Arg(self, node):
|
||||||
if m.matches(node.value, m.Name()):
|
if m.matches(node.value, m.Name()):
|
||||||
dad = self.get_metadata(ParentNodeProvider, node)
|
parent = self.get_metadata(ParentNodeProvider, node)
|
||||||
if m.matches(dad, m.ClassDef()) and dad.bases:
|
if m.matches(parent, m.ClassDef()) and parent.bases:
|
||||||
self._update_class_dependency(dad.name.value, node.value.value)
|
self._update_class_dependency(parent.name.value, node.value.value)
|
||||||
|
|
||||||
def leave_Dict(self, node):
|
def leave_Dict(self, node):
|
||||||
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||||
if m.matches(dad, m.Assign(targets=[m.AssignTarget()])):
|
if m.matches(parent, m.Assign(targets=[m.AssignTarget()])):
|
||||||
name = dad.targets[0].target.value
|
name = parent.targets[0].target.value
|
||||||
if name in self.assignments:
|
if name in self.assignments:
|
||||||
for k in node.elements:
|
for k in node.elements:
|
||||||
dep_name = k.value.value
|
dep_name = k.value.value
|
||||||
|
@ -132,9 +131,9 @@ class ClassFinder(CSTVisitor):
|
||||||
if hasattr(node.decorator, "args"):
|
if hasattr(node.decorator, "args"):
|
||||||
for k in node.decorator.args:
|
for k in node.decorator.args:
|
||||||
if k.value.value in self.assignments:
|
if k.value.value in self.assignments:
|
||||||
dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||||
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
|
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
|
||||||
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else dad.name.value
|
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
|
||||||
self._update_class_dependency(name, k.value.value)
|
self._update_class_dependency(name, k.value.value)
|
||||||
|
|
||||||
def leave_Module(self, node):
|
def leave_Module(self, node):
|
||||||
|
@ -304,11 +303,9 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
||||||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||||
temp_module = cst.Module(body=[result_node])
|
temp_module = cst.Module(body=[result_node])
|
||||||
new_module = MetadataWrapper(temp_module)
|
new_module = MetadataWrapper(temp_module)
|
||||||
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[
|
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
|
||||||
0
|
new_replacement_body = new_replacement_class.body[0].body # get the indented block
|
||||||
] # get the indented block
|
return original_node.with_changes(body=new_replacement_body)
|
||||||
|
|
||||||
return original_node.with_changes(body=new_replacement_class.body)
|
|
||||||
|
|
||||||
|
|
||||||
class DiffConverterTransformer(CSTTransformer):
|
class DiffConverterTransformer(CSTTransformer):
|
||||||
|
@ -327,13 +324,6 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
self.global_scope_index = 0
|
self.global_scope_index = 0
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def leave_FunctionDef(self, original_node, node):
|
|
||||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
|
||||||
if m.matches(parent_node, m.Module()):
|
|
||||||
self.global_scope_index += 100
|
|
||||||
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||||
1. Get the original source code
|
1. Get the original source code
|
||||||
|
@ -351,6 +341,13 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
imported_class = self.python_module.code_for_node(imported_.name)
|
imported_class = self.python_module.code_for_node(imported_.name)
|
||||||
self.imported_mapping[imported_class] = import_statement
|
self.imported_mapping[imported_class] = import_statement
|
||||||
|
|
||||||
|
def leave_FunctionDef(self, original_node, node):
|
||||||
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||||
|
if m.matches(parent_node, m.Module()):
|
||||||
|
self.global_scope_index += 100
|
||||||
|
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
|
||||||
|
return node
|
||||||
|
|
||||||
def leave_SimpleStatementLine(self, original_node, updated_node):
|
def leave_SimpleStatementLine(self, original_node, updated_node):
|
||||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||||
self.all_imports.append(updated_node)
|
self.all_imports.append(updated_node)
|
||||||
|
@ -423,15 +420,15 @@ class DiffConverterTransformer(CSTTransformer):
|
||||||
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||||
return updated_node
|
return updated_node
|
||||||
|
|
||||||
# def leave_If(self, original_node, node):
|
def leave_If(self, original_node, node):
|
||||||
# parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||||
# if m.matches(parent_node, m.Module()):
|
if m.matches(parent_node, m.Module()):
|
||||||
# full_statement = self.python_module.code_for_node(original_node.test)
|
full_statement = self.python_module.code_for_node(original_node.test)
|
||||||
# if re.search(r"[\s\S]*is_.*available", full_statement):
|
if re.search(r"[\s\S]*is_.*available", full_statement):
|
||||||
# self.all_imports.append(node)
|
self.all_imports.append(node)
|
||||||
# else:
|
elif full_statement not in self.new_body:
|
||||||
# self.new_body[node] = {"insert_idx":self.global_scope_index, "node":node}
|
self.new_body[node] = {"insert_idx":self.global_scope_index, "node":node}
|
||||||
# return node
|
return node
|
||||||
|
|
||||||
def leave_Module(self, original_node: cst.Assign, node):
|
def leave_Module(self, original_node: cst.Assign, node):
|
||||||
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
|
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
|
||||||
|
|
Loading…
Reference in New Issue