handle functions

Arthur Zucker 2024-05-17 20:51:34 +02:00
parent 768801cbac
commit 274ac8801d
2 changed files with 79 additions and 14 deletions


@@ -44,6 +44,64 @@ from ...utils import (
from .configuration_gemma import GemmaConfig
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
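For readers unfamiliar with the unpadding helper, a quick illustration of what it returns (illustration only, not part of the commit; it assumes the torch / torch.nn.functional imports already present in this module):

import torch
import torch.nn.functional as F

# Illustration only: two sequences of lengths 3 and 2, right-padded to length 4.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 2, 4, 5])               positions of real tokens in the flattened batch
# cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)  cumulative lengths, as consumed by FlashAttention's varlen kernels
# max_seqlen -> 3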
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
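A shape sanity check for the rotary helper (illustration only, not part of the commit; cos/sin are assumed to already have shape [batch, seq_len, head_dim], as produced by the model's rotary embedding module):

import torch

batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)  # assumed precomputed cosines
sin = torch.randn(batch, seq_len, head_dim)  # assumed precomputed sines

# unsqueeze_dim=1 turns cos/sin into [batch, 1, seq_len, head_dim] so they broadcast over the heads axis.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape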
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
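The equivalence stated in the docstring can be checked directly (illustration only, not part of the commit):

import torch

kv = torch.randn(2, 4, 16, 64)     # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=3)  # -> (2, 12, 16, 64)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))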
""" PyTorch Gemma model."""
import math
@@ -530,6 +588,7 @@ class GemmaSdpaAttention(GemmaAttention):
        return attn_output, None, past_key_value


GEMMA_ATTENTION_CLASSES = {
    "eager": GemmaAttention,
    "flash_attention_2": GemmaFlashAttention2,


@@ -21,16 +21,19 @@ def get_module_source_from_name(module_name: str) -> str:
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst.metadata import MetadataWrapper, ParentNodeProvider
class ClassFinder(CSTVisitor):
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, python_module):
        self.module = python_module
        self.classes = {}  # class LlamaAttention
        self.imports = {}  # from flash_attn import
        self.function_def = {}  # def repeat_kv
        self.assignments = {}  # LLAMA_DOCSTRING
        self.protected_imports = {}  # if is_xxx_available()

    def visit_ClassDef(self, node: ClassDef) -> None:
        self.classes[node.name.value] = node
@@ -53,8 +56,10 @@ class ClassFinder(CSTVisitor):
                self.imports[node.body[0].names] = node
            case cst.SimpleStatementLine(body=[cst.ImportFrom(_)]):
                self.imports[node.body[0].names] = node
            case cst.SimpleStatementLine(boyd=[cst.FunctionDef(_)]):
                self.function_def[node.name.value] = node

    def visit_FunctionDef(self, node):
        if isinstance(self.get_metadata(cst.metadata.ParentNodeProvider, node), cst.Module):
            self.function_def[node.name.value] = node
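The new visit_FunctionDef records only functions whose parent node is the module itself (top-level helpers such as rotate_half or repeat_kv) and skips methods. A self-contained sketch of that metadata check (illustration only, not the converter's actual class):

import libcst as cst
from libcst.metadata import MetadataWrapper, ParentNodeProvider

class TopLevelFunctionFinder(cst.CSTVisitor):
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self):
        self.function_def = {}

    def visit_FunctionDef(self, node):
        # Keep the function only if its parent is the Module, i.e. it is not a method.
        if isinstance(self.get_metadata(ParentNodeProvider, node), cst.Module):
            self.function_def[node.name.value] = node

wrapper = MetadataWrapper(cst.parse_module("def rotate_half(x): ...\nclass A:\n    def forward(self): ...\n"))
finder = TopLevelFunctionFinder()
wrapper.visit(finder)
print(finder.function_def.keys())  # dict_keys(['rotate_half']); 'forward' is skipped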
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
@@ -75,8 +80,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
                return self.new_name.lower()
            else:
                return self.new_name.title()
        return self.regex.sub(replace, text)
    @m.leave(m.Name() | m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
@@ -85,10 +90,13 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
        print(f"changed: {updated_node.value} -> {update}")
        return updated_node.with_changes(value=self.preserve_case_replace(updated_node.value))
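For context, preserve_case_replace substitutes the old model name with the new one while keeping the original casing. A minimal standalone sketch of that idea (illustration only; the all-caps branch is an assumption, since only the lower/title branches appear in the excerpt above):

import re

old_name, new_name = "llama", "gemma"
pattern = re.compile(old_name, re.IGNORECASE)

def replace(match):
    text = match.group()
    if text.islower():
        return new_name.lower()
    elif text.isupper():
        return new_name.upper()  # assumption: all-caps names map to all-caps
    else:
        return new_name.title()

print(pattern.sub(replace, "class LlamaAttention:  # see LLAMA_DOCSTRING"))
# -> class GemmaAttention:  # see GEMMA_DOCSTRING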
def find_classes_in_file(module, old_id="llama", new_id="gemma"):
    transformer = ReplaceNameTransformer(old_id, new_id)
    new_module = module.visit(transformer)
    new_module = MetadataWrapper(new_module)
    class_finder = ClassFinder(new_module)
    new_module.visit(class_finder)
    return class_finder
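A hedged usage sketch for find_classes_in_file (not part of the commit; the file path is hypothetical):

import libcst as cst

with open("modeling_llama.py") as f:  # hypothetical source file
    module = cst.parse_module(f.read())

finder = find_classes_in_file(module, old_id="llama", new_id="gemma")
print(list(finder.classes))       # e.g. ['GemmaAttention', 'GemmaMLP', ...] after renaming
print(list(finder.function_def))  # e.g. ['rotate_half', 'repeat_kv', ...]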
@@ -122,7 +130,9 @@ class DiffConverterTransformer(CSTTransformer):
            if parent_package not in self.visited_module:
                class_finder = find_classes_in_file(self.transformers_imports[parent_package])
                self.visited_module[parent_package] = class_finder
            self.class_mapping[self.python_module.code_for_node(node)] = self.visited_module[
                parent_package
            ].classes[node.targets[0].target.value]
    def leave_SimpleStatementLine(self, original_node: cst.Assign, updated_node: cst.CSTNode):
        match updated_node:
@@ -151,12 +161,8 @@ class DiffConverterTransformer(CSTTransformer):
            new_body += list(visiter.assignments.values())
            new_body += list(visiter.function_def.values())
        return node.with_changes(body=[*new_body, *node.body])
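The single-line return above prepends every collected dependency (imports, assignments, function definitions) to the module body. A minimal libcst sketch of that pattern (illustration only, not the converter itself):

import libcst as cst

module = cst.parse_module("class GemmaAttention:\n    pass\n")
helpers = cst.parse_module("def repeat_kv(x, n):\n    return x\n").body

# Same with_changes(body=[...]) pattern as above: new statements first, original body after.
new_module = module.with_changes(body=[*helpers, *module.body])
print(new_module.code)  # repeat_kv is now defined before GemmaAttention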
if __name__ == "__main__":
# Parse the Python file