attempt diffs for 3 files
parent: e08d8eb963
commit: 92b6218e18

@@ -0,0 +1,58 @@
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter

GemmaConverter = ModelConverter(__file__)

class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
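        # i.e. Llama casts the normalized states back to half precision before multiplying by the weight,
        # while Gemma multiplies by (1 + weight) in float32 and only casts the final product.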
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)
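
# Modules that match their Llama counterparts 1:1 are only registered with the converter; modules that behave differently are redefined in this file.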
GemmaConverter.register("GemmaRotaryEmbedding", LlamaRotaryEmbedding)
GemmaConverter.register("GemmaMLP", LlamaMLP)

GemmaAttention = GemmaConverter.register("GemmaAttention", LlamaAttention)
GemmaSdpaAttention = GemmaConverter.register("GemmaSdpaAttention", LlamaAttention)
GemmaFlashAttention2 = GemmaConverter.register("GemmaFlashAttention2", LlamaAttention)

GEMMA_ATTENTION_CLASSES = {"eager": GemmaAttention, "flash_attention_2": GemmaFlashAttention2, "sdpa": GemmaSdpaAttention}

GemmaConverter.register("GemmaDecoderLayer", LlamaDecoderLayer)
GemmaConverter.register("GemmaPreTrainedModel", LlamaPreTrainedModel)

class GemmaModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
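        # Gemma scales the input embeddings by sqrt(hidden_size); the scale is created in the embeddings' dtype.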
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        return super().forward(None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)

GemmaConverter.register("GemmaForCausalLM", LlamaForCausalLM)

@@ -0,0 +1,395 @@
from typing import Tuple
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers import StableLmConfig
from transformers.utils import ModelConverter

StableLmConverter = ModelConverter(__file__)

StableLmRMSNorm = StableLmConverter.register("StableLmRMSNorm", LlamaRMSNorm)
StableLmRotaryEmbedding = StableLmConverter.register("StableLmRotaryEmbedding", LlamaRotaryEmbedding)
StableLmMLP = StableLmConverter.register("StableLmMLP", LlamaMLP)

class StableLmLayerNormPerHead(nn.Module):
    def __init__(self, dim, num_heads, eps=1e-5, bias=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)])

    def forward(self, hidden_states: torch.Tensor):
        # Split along the num_heads axis to get per-head inputs
        # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
        states_per_heads = torch.split(hidden_states, 1, dim=1)
        # Normalize and merge the heads back together
        return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1)

class StableLmAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig, layer_idx: int | None = None):
        super().__init__(config, layer_idx)  # here call to super means
        # we should copy super
        self.qk_layernorm = config.qk_layernorm
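        # StableLM can optionally apply a per-head LayerNorm to the query and key states (config.qk_layernorm).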
        self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps)
        self.k_layernorm = StableLmLayerNormPerHead(
            self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            # Specific to RoPE models with partial rotation
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

class StableLmSdpaAttention(StableLmAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            # Specific to RoPE models with partial rotation
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout.p if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

class StableLmFlashAttention2(LlamaFlashAttention2):
    """
    StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout.p if self.training else 0.0

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

StableLm_ATTENTION_CLASSES = {"eager": StableLmAttention, "flash_attention_2": StableLmFlashAttention2, "sdpa": StableLmSdpaAttention}

class StableLmDecoderLayer(nn.Module):
    def __init__(self, config: StableLmConfig, layer_idx: int):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.hidden_size = config.hidden_size
        self.self_attn = StableLm_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        self.mlp = StableLmMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = None
        if not self.use_parallel_residual:
            self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        self_attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
        if self.use_parallel_residual:
            # x = x + attn(ln1(x)) + mlp(ln1(x))
            # Fully Connected
            mlp_output = self.mlp(hidden_states)
            mlp_output = self.dropout(mlp_output)
            hidden_states = residual + self_attn_output + mlp_output
        else:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            residual = residual + self_attn_output
            # Fully Connected
            mlp_output = self.mlp(self.post_attention_layernorm(residual))
            mlp_output = self.dropout(mlp_output)
            hidden_states = residual + mlp_output

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

StableLmPreTrainedModel = StableLmConverter.register("StableLmPreTrainedModel", LlamaPreTrainedModel)
StableLmModel = StableLmConverter.register("StableLmModel", LlamaModel)
StableLmForCausalLM = StableLmConverter.register("StableLmForCausalLM", LlamaForCausalLM)
StableLmForSequenceClassification = StableLmConverter.register("StableLmForSequenceClassification", LlamaForSequenceClassification)

@@ -0,0 +1,153 @@
from typing import List, Tuple
from torch import FloatTensor, LongTensor, Tensor
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers import Starcoder2Config
from transformers.utils import ModelConverter

Starcoder2Converter = ModelConverter(__file__)

Starcoder2RMSNorm = Starcoder2Converter.register("Starcoder2RMSNorm", LlamaRMSNorm)
Starcoder2RotaryEmbedding = Starcoder2Converter.register("Starcoder2RotaryEmbedding", LlamaRotaryEmbedding)
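
# Unlike Llama's gated MLP, Starcoder2 uses a GPT-2 style MLP (c_fc -> activation -> c_proj) with residual dropout, so it is written out here instead of being registered.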
class Starcoder2MLP(nn.Module):
    def __init__(self, config: Starcoder2Config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
        self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
        self.act = ACT2FN[config.hidden_act]
        self.residual_dropout = config.residual_dropout

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
        return hidden_states

# TODO either we support this, or we don't allow call to super?
# if part of the super is used, then we are fucked. Let's restrict this to init?

# TODO if a class is not registered, the original should be copied with replaces?
# Copied from where? No.
# But then how do we check the architecture etc.

# TODO do we support multiple inheritance?
# This will depend on whether we usually copy from more than one module
# Mixtral for example?

class Starcoder2Attention(LlamaAttention):
    def __init__(self, config: LlamaConfig, layer_idx: int | None = None):
        super().__init__(config, layer_idx)  # here call to super means
        # we should copy super
        self.attention_dropout = config.attention_dropout
        self.residual_dropout = config.residual_dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)
        attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

Starcoder2SdpaAttention = Starcoder2Converter.register("Starcoder2SdpaAttention", LlamaAttention)
Starcoder2FlashAttention2 = Starcoder2Converter.register("Starcoder2FlashAttention2", LlamaAttention)

STARCODER2_ATTENTION_CLASSES = {"eager": Starcoder2Attention, "flash_attention_2": Starcoder2FlashAttention2, "sdpa": Starcoder2SdpaAttention}

Starcoder2DecoderLayer = Starcoder2Converter.register("Starcoder2DecoderLayer", LlamaDecoderLayer)
Starcoder2PreTrainedModel = Starcoder2Converter.register("Starcoder2PreTrainedModel", LlamaPreTrainedModel)

class Starcoder2Model(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self.embedding_dropout = config.embedding_dropout

    def forward(
        self,
        input_ids: LongTensor = None,
        attention_mask: Tensor | None = None,
        position_ids: LongTensor | None = None,
        past_key_values: List[FloatTensor] | None = None,
        inputs_embeds: FloatTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: LongTensor | None = None,
    ) -> Tuple | BaseModelOutputWithPast:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
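        # Starcoder2 applies dropout to the input embeddings before running the decoder stack.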
        hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)
        return super().forward(None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)

Starcoder2ForCausalLM = Starcoder2Converter.register("Starcoder2ForCausalLM", LlamaForCausalLM)
Starcoder2ForSequenceClassification = Starcoder2Converter.register("Starcoder2ForSequenceClassification", LlamaForSequenceClassification)