diff --git a/src/transformers/models/llama/diff_llama.py b/src/transformers/models/llama/diff_llama.py
index 226d14c18b..836528ee21 100644
--- a/src/transformers/models/llama/diff_llama.py
+++ b/src/transformers/models/llama/diff_llama.py
@@ -17,8 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch LLaMA model."""
-
 import math
 from typing import List, Optional, Tuple, Union
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 83f8c650a1..75f7e906a1 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -23,15 +23,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import math
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from flash_attn import flash_attn_func, flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -50,14 +47,16 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
 from .configuration_llama import LlamaConfig
-
-"""PyTorch LLaMA model."""
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 logger = logging.get_logger(__name__)
 
diff --git a/utils/diff_model_converter.py b/utils/diff_model_converter.py
index 7fabe57703..37943c6a3e 100644
--- a/utils/diff_model_converter.py
+++ b/utils/diff_model_converter.py
@@ -103,25 +103,24 @@ class ClassFinder(CSTVisitor):
     def leave_If(self, node):
         for stmt in node.body.body:
             if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
-                # TODO this can actually be a duplicated import?
-                self.imports[stmt.body[0].names] = node  # match the visit simple statement line to overwrite it
+                self.imports[stmt.body[0].names] = node
 
     def leave_Name(self, node):
         if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
-            dad = self.get_metadata(cst.metadata.ScopeProvider, node)
-            if not isinstance(dad, cst.metadata.scope_provider.GlobalScope):
-                self._update_class_dependency(dad._name_prefix.split(".")[0], node.value)
+            parent = self.get_metadata(cst.metadata.ScopeProvider, node)
+            if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
+                self._update_class_dependency(parent._name_prefix.split(".")[0], node.value)
 
     def leave_Arg(self, node):
         if m.matches(node.value, m.Name()):
-            dad = self.get_metadata(ParentNodeProvider, node)
-            if m.matches(dad, m.ClassDef()) and dad.bases:
-                self._update_class_dependency(dad.name.value, node.value.value)
+            parent = self.get_metadata(ParentNodeProvider, node)
+            if m.matches(parent, m.ClassDef()) and parent.bases:
+                self._update_class_dependency(parent.name.value, node.value.value)
 
     def leave_Dict(self, node):
-        dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
-        if m.matches(dad, m.Assign(targets=[m.AssignTarget()])):
-            name = dad.targets[0].target.value
+        parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
+        if m.matches(parent, m.Assign(targets=[m.AssignTarget()])):
+            name = parent.targets[0].target.value
             if name in self.assignments:
                 for k in node.elements:
                     dep_name = k.value.value
@@ -132,9 +131,9 @@ class ClassFinder(CSTVisitor):
         if hasattr(node.decorator, "args"):
             for k in node.decorator.args:
                 if k.value.value in self.assignments:
-                    dad = self.get_metadata(cst.metadata.ParentNodeProvider, node)
+                    parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
                     scope = self.get_metadata(cst.metadata.ScopeProvider, node)
-                    name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else dad.name.value
+                    name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
                     self._update_class_dependency(name, k.value.value)
 
     def leave_Module(self, node):
@@ -304,11 +303,9 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
     result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
     temp_module = cst.Module(body=[result_node])
     new_module = MetadataWrapper(temp_module)
-    new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods)).body[
-        0
-    ]  # get the indented block
-
-    return original_node.with_changes(body=new_replacement_class.body)
+    new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
+    new_replacement_body = new_replacement_class.body[0].body  # get the indented block
+    return original_node.with_changes(body=new_replacement_body)
 
 
 class DiffConverterTransformer(CSTTransformer):
@@ -327,13 +324,6 @@ class DiffConverterTransformer(CSTTransformer):
         self.global_scope_index = 0
         # fmt: on
 
-    def leave_FunctionDef(self, original_node, node):
-        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
-        if m.matches(parent_node, m.Module()):
-            self.global_scope_index += 100
-            self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
-        return node
-
     def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
         """When visiting imports from `transformers.models.xxx` we need to:
         1. Get the original source code
@@ -351,6 +341,13 @@ class DiffConverterTransformer(CSTTransformer):
             imported_class = self.python_module.code_for_node(imported_.name)
             self.imported_mapping[imported_class] = import_statement
 
+    def leave_FunctionDef(self, original_node, node):
+        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
+        if m.matches(parent_node, m.Module()):
+            self.global_scope_index += 100
+            self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
+        return node
+
     def leave_SimpleStatementLine(self, original_node, updated_node):
         if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
             self.all_imports.append(updated_node)
@@ -423,15 +420,15 @@ class DiffConverterTransformer(CSTTransformer):
             self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
         return updated_node
 
-    # def leave_If(self, original_node, node):
-    #     parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
-    #     if m.matches(parent_node, m.Module()):
-    #         full_statement = self.python_module.code_for_node(original_node.test)
-    #         if re.search(r"[\s\S]*is_.*available", full_statement):
-    #             self.all_imports.append(node)
-    #         else:
-    #             self.new_body[node] = {"insert_idx":self.global_scope_index, "node":node}
-    #     return node
+    def leave_If(self, original_node, node):
+        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
+        if m.matches(parent_node, m.Module()):
+            full_statement = self.python_module.code_for_node(original_node.test)
+            if re.search(r"[\s\S]*is_.*available", full_statement):
+                self.all_imports.append(node)
+            elif full_statement not in self.new_body:
+                self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
+        return node
 
     def leave_Module(self, original_node: cst.Assign, node):
         imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
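
A note on the `modeling_llama.py` hunk above: the hard `flash_attn` import becomes conditional on `is_flash_attn_2_available()`, so importing the model no longer fails on machines without the optional package. Below is a minimal sketch of the same guarded-import idiom; `is_package_available` is a simplified stand-in for the real helper (which also verifies the installed version), and `attention` is a hypothetical dispatcher, not the library's API.

import importlib.util

import torch


def is_package_available(name: str) -> bool:
    # Simplified stand-in for transformers' availability helpers; the real
    # is_flash_attn_2_available() additionally checks the installed version.
    return importlib.util.find_spec(name) is not None


# Guarded import: the fused kernels are only pulled in when the optional
# dependency is installed, so importing this module never raises.
if is_package_available("flash_attn"):
    from flash_attn import flash_attn_func
else:
    flash_attn_func = None


def attention(q, k, v):
    # Inputs are (batch, seqlen, nheads, headdim), the layout flash-attn expects.
    # The fused kernel only supports fp16/bf16 CUDA tensors, hence the extra checks.
    if flash_attn_func is not None and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_func(q, k, v, causal=True)
    # Fallback: PyTorch's built-in SDPA, which wants (batch, nheads, seqlen, headdim).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2)


# Exercises the SDPA fallback on CPU.
x = torch.randn(1, 8, 4, 16)  # (batch, seqlen, nheads, headdim)
print(attention(x, x, x).shape)  # -> torch.Size([1, 8, 4, 16])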
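
The `replace_call_to_super` hunk fixes how the transformed class body is recovered: the visit result is a `Module`, its `body[0]` is the `ClassDef`, and that node's `.body` is the `IndentedBlock` to transplant back. A small sketch of that libcst node shape, with illustrative class names of my own:

import libcst as cst

mod = cst.parse_module("class A:\n    x = 1\n")
cls = mod.body[0]            # Module.body[0] is the ClassDef
block = cls.body             # ClassDef.body is the IndentedBlock
print(type(block).__name__)  # -> IndentedBlock

# Transplanting the block onto another class reuses A's statements:
other = cst.parse_module("class B:\n    pass\n").body[0]
print(cst.Module(body=[other.with_changes(body=block)]).code)
# class B:
#     x = 1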
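
Finally, the previously commented-out `leave_If` is restored in `DiffConverterTransformer`: module-level `if is_..._available():` guards are routed to `all_imports`, while other top-level conditionals go to `new_body` (now with a duplicate check). A minimal standalone sketch of that classification logic; the class and variable names here are illustrative, not the converter's:

import re

import libcst as cst
import libcst.matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider

SOURCE = """\
import math

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func

if DEBUG:
    x = 1
"""


class IfClassifier(cst.CSTVisitor):
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, module: cst.Module):
        self.module = module
        self.guarded_imports = []  # `if is_xxx_available():` import guards
        self.other_ifs = []        # every other top-level conditional

    def leave_If(self, node: cst.If) -> None:
        parent = self.get_metadata(ParentNodeProvider, node)
        if not m.matches(parent, m.Module()):
            return  # only classify module-level conditionals
        test_code = self.module.code_for_node(node.test)
        if re.search(r"is_.*available", test_code):
            self.guarded_imports.append(node)
        else:
            self.other_ifs.append(node)


wrapper = MetadataWrapper(cst.parse_module(SOURCE))
module = wrapper.module  # the wrapper deep-copies the tree, so use its copy
finder = IfClassifier(module)
wrapper.visit(finder)
print(len(finder.guarded_imports), len(finder.other_ifs))  # -> 1 1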