remove other diff files

parent 53a4ce871c
commit 0c7e43eb8a

@@ -1,141 +0,0 @@
# coding=utf-8
# Copyright 2024 Cohere team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the Llama model definition file in transformers

from typing import Optional, Tuple

import torch
import torch.nn as nn

from transformers import CohereConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRotaryEmbedding,
)


class CohereLayerNorm(nn.Module):
    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)
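
# A minimal sanity-check sketch, not part of the original file: CohereLayerNorm is a
# bias-free LayerNorm computed in float32. Assuming PyTorch >= 2.1 (for the `bias`
# flag on nn.LayerNorm), it should match the built-in up to numerical tolerance:
#
#     x = torch.randn(2, 4, 8)
#     ref = nn.LayerNorm(8, eps=1e-5, bias=False)
#     assert torch.allclose(CohereLayerNorm(hidden_size=8)(x), ref(x), atol=1e-5)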


class CohereRotaryEmbedding(LlamaRotaryEmbedding):
    def rotate_half(self, x):
        # Split and rotate interleaved (even, odd) channel pairs
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
        return rot_x

    def forward(self, q, k, position_ids=None, unsqueeze_dim=1):
        dtype = q.dtype
        q, k = q.float(), k.float()
        cos, sin = self.compute_cos_sin(q, position_ids)  # helper assumed to be provided by the embedding class
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
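
# For illustration, not part of the original file: unlike Llama's rotate_half, which
# splits the head dimension into two contiguous halves, Cohere rotates interleaved
# (even, odd) channel pairs:
#
#     x = torch.arange(6.0)                        # [ 0, 1,  2, 3,  4, 5]
#     CohereRotaryEmbedding.rotate_half(None, x)   # [-1, 0, -3, 2, -5, 4]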


CohereMLP = LlamaMLP
CohereAttention = LlamaAttention
CohereSdpaAttention = LlamaAttention
CohereFlashAttention2 = LlamaAttention

COHERE_ATTENTION_CLASSES = {"eager": CohereAttention, "flash_attention_2": CohereFlashAttention2, "sdpa": CohereSdpaAttention}


class CohereDecoderLayer(nn.Module):
    def __init__(self, config: CohereConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = CohereMLP(config)
        self.input_layernorm = CohereLayerNorm(hidden_size=config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )

        # Fully Connected (fed the same normalized input as attention)
        hidden_states_mlp = self.mlp(hidden_states)

        # Add everything together (main difference with Llama)
        hidden_states = residual + hidden_states_attention + hidden_states_mlp

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
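
# For illustration, not part of the original file: the decoder layer above uses a
# parallel residual, with attention and MLP reading the same normalized input,
# instead of Llama's sequential residual. Schematically:
#
#     Llama:  h = x + attn(ln_1(x));  out = h + mlp(ln_2(h))
#     Cohere: out = x + attn(ln(x)) + mlp(ln(x))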


CoherePreTrainedModel = LlamaPreTrainedModel
CohereModel = LlamaModel


class CohereForCausalLM(LlamaForCausalLM):
    _tied_weights_keys = ["lm_head.weight"]

    # Ignore copy
    def __init__(self, config):
        super().__init__(config)
        self.logit_scale = config.logit_scale
        self.tie_word_embeddings = config.tie_word_embeddings

    def forward(self, **kwargs):
        output = super().forward(**kwargs)
        # Cohere scales the lm_head logits by a constant factor
        output.logits = output.logits * self.logit_scale
        return output
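
# For context, not part of the original file: with the lm_head weight tied to the
# embedding matrix E, the model's output is
#
#     logits[b, t, v] = logit_scale * <hidden_states[b, t], E[v]>
#
# i.e. logit_scale acts as a fixed inverse softmax temperature.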

@@ -1,150 +0,0 @@
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor, Tensor

from transformers import Starcoder2Config
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
    ACT2FN,
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)


Starcoder2RMSNorm = LlamaRMSNorm
Starcoder2RotaryEmbedding = LlamaRotaryEmbedding


class Starcoder2MLP(nn.Module):
    def __init__(self, config: Starcoder2Config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
        self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
        self.act = ACT2FN[config.hidden_act]
        self.residual_dropout = config.residual_dropout

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
        return hidden_states
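
# A minimal shape-check sketch, not part of the original file; the config fields
# match those used in __init__ above:
#
#     cfg = Starcoder2Config(hidden_size=8, intermediate_size=32,
#                            use_bias=True, residual_dropout=0.0)
#     assert Starcoder2MLP(cfg)(torch.randn(2, 3, 8)).shape == (2, 3, 8)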


class Starcoder2Attention(LlamaAttention):
    def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.attention_dropout = config.attention_dropout
        self.residual_dropout = config.residual_dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)
        attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
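
# For context, not part of the original file: ignoring dropout, the eager path above
# computes softmax(Q @ K^T / sqrt(head_dim) + mask) @ V, the same quantity the fused
# kernel evaluates:
#
#     attn_output = torch.nn.functional.scaled_dot_product_attention(
#         query_states, key_states, value_states, attn_mask=attention_mask)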


Starcoder2SdpaAttention = LlamaAttention
Starcoder2FlashAttention2 = LlamaAttention

STARCODER2_ATTENTION_CLASSES = {"eager": Starcoder2Attention, "flash_attention_2": Starcoder2FlashAttention2, "sdpa": Starcoder2SdpaAttention}


Starcoder2DecoderLayer = LlamaDecoderLayer
Starcoder2PreTrainedModel = LlamaPreTrainedModel


class Starcoder2Model(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self.embedding_dropout = config.embedding_dropout

    def forward(
        self,
        input_ids: Optional[LongTensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[LongTensor] = None,
        past_key_values: Optional[List[FloatTensor]] = None,
        inputs_embeds: Optional[FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = nn.functional.dropout(inputs_embeds, p=self.embedding_dropout, training=self.training)
        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=hidden_states,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
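
# For context, not part of the original file: the only change relative to LlamaModel
# is the embedding dropout applied to inputs_embeds before delegating to
# LlamaModel.forward; mask handling, the decoder layers, and the final norm are inherited.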


Starcoder2ForCausalLM = LlamaForCausalLM
Starcoder2ForSequenceClassification = LlamaForSequenceClassification