From 274ac8801d318da36003dd9ec6f7d1f122e11e00 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 17 May 2024 20:51:34 +0200
Subject: [PATCH] handle functions

---
 .../models/gemma/modeling_gemma.py | 59 +++++++++++++++++++
 utils/diff_model_converter.py      | 34 ++++++-----
 2 files changed, 79 insertions(+), 14 deletions(-)

diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 2b79ef71e8..98e591a0f9 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -44,6 +44,64 @@ from ...utils import (
 from .configuration_gemma import GemmaConfig
 
 
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 """ PyTorch Gemma model."""
 
 import math
@@ -530,6 +588,7 @@ class GemmaSdpaAttention(GemmaAttention):
 
         return attn_output, None, past_key_value
 
+
 GEMMA_ATTENTION_CLASSES = {
     "eager": GemmaAttention,
     "flash_attention_2": GemmaFlashAttention2,
diff --git a/utils/diff_model_converter.py b/utils/diff_model_converter.py
index 4a51955119..ca9de8db57 100644
--- a/utils/diff_model_converter.py
+++ b/utils/diff_model_converter.py
@@ -21,16 +21,19 @@ def get_module_source_from_name(module_name: str) -> str:
 
 
 from libcst import ClassDef, CSTTransformer, CSTVisitor
+from libcst.metadata import MetadataWrapper, ParentNodeProvider
 
 
 class ClassFinder(CSTVisitor):
+    METADATA_DEPENDENCIES = (ParentNodeProvider,)
+
     def __init__(self, python_module):
         self.module = python_module
-        self.classes = {} # class LlamaAttentino
-        self.imports = {} # from flash_attn import
-        self.function_def = {} # def repeat_kv
+        self.classes = {}  # class LlamaAttentino
+        self.imports = {}  # from flash_attn import
+        self.function_def = {}  # def repeat_kv
         self.assignments = {}  # LLAMA_DOCSTRING
-        self.protected_imports = {} # if is_xxx_available()
+        self.protected_imports = {}  # if is_xxx_available()
 
     def visit_ClassDef(self, node: ClassDef) -> None:
         self.classes[node.name.value] = node
@@ -53,8 +56,10 @@ class ClassFinder(CSTVisitor):
                 self.imports[node.body[0].names] = node
             case cst.SimpleStatementLine(body=[cst.ImportFrom(_)]):
                 self.imports[node.body[0].names] = node
-            case cst.SimpleStatementLine(boyd=[cst.FunctionDef(_)]):
-                self.function_def[node.name.value] = node
+
+    def visit_FunctionDef(self, node):
+        if isinstance(self.get_metadata(cst.metadata.ParentNodeProvider, node), cst.Module):
+            self.function_def[node.name.value] = node
 
 
 class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
@@ -75,8 +80,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
                 return self.new_name.lower()
             else:
                 return self.new_name.title()
-        return self.regex.sub(replace, text)
 
+        return self.regex.sub(replace, text)
 
     @m.leave(m.Name() | m.SimpleString() | m.Comment())
     def replace_name(self, original_node, updated_node):
@@ -85,10 +90,13 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
             print(f"changed: {updated_node.value} -> {update}")
         return updated_node.with_changes(value=self.preserve_case_replace(updated_node.value))
 
+
 def find_classes_in_file(module, old_id="llama", new_id="gemma"):
     transformer = ReplaceNameTransformer(old_id, new_id)
     new_module = module.visit(transformer)
+
+    new_module = MetadataWrapper(new_module)
+
     class_finder = ClassFinder(new_module)
     new_module.visit(class_finder)
     return class_finder
@@ -122,7 +130,9 @@ class DiffConverterTransformer(CSTTransformer):
                     if parent_package not in self.visited_module:
                         class_finder = find_classes_in_file(self.transformers_imports[parent_package])
                         self.visited_module[parent_package] = class_finder
-                    self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[parent_package].classes[node.targets[0].target.value]
+                    self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[
+                        parent_package
+                    ].classes[node.targets[0].target.value]
 
     def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
         match updated_node:
@@ -151,12 +161,8 @@
             new_body += list(visiter.assignments.values())
         new_body += list(visiter.function_def.values())
 
-        return node.with_changes(
-            body=[
-                *new_body,
-                *node.body
-            ]
-        )
+        return node.with_changes(body=[*new_body, *node.body])
+
 
 if __name__ == "__main__":
     # Parse the Python file
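
Illustration (not part of the patch): a minimal, self-contained sketch of the libcst pattern this change relies on. Declaring ParentNodeProvider in METADATA_DEPENDENCIES and checking that a FunctionDef's direct parent is the Module is what lets the new ClassFinder.visit_FunctionDef collect only top-level functions such as repeat_kv, once find_classes_in_file wraps the transformed module in a MetadataWrapper. The TopLevelFunctionFinder class and the toy source string are made up for this example.

import libcst as cst
from libcst.metadata import MetadataWrapper, ParentNodeProvider


class TopLevelFunctionFinder(cst.CSTVisitor):
    # Mirrors the ParentNodeProvider check added to ClassFinder.visit_FunctionDef.
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self):
        super().__init__()
        self.function_def = {}

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        # Keep a function only when its direct parent is the Module itself,
        # so methods defined inside classes are skipped.
        if isinstance(self.get_metadata(ParentNodeProvider, node), cst.Module):
            self.function_def[node.name.value] = node


source = "def repeat_kv(x, n):\n    return x\n\n\nclass Foo:\n    def bar(self):\n        pass\n"
wrapper = MetadataWrapper(cst.parse_module(source))  # resolves metadata before visiting
finder = TopLevelFunctionFinder()
wrapper.visit(finder)
print(list(finder.function_def))  # ['repeat_kv'] -- Foo.bar is not collected

This is also why the patch wraps new_module in MetadataWrapper inside find_classes_in_file before running ClassFinder: get_metadata() only works when the visitor is driven through a wrapper that has resolved its declared metadata dependencies.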