updates
This commit is contained in:
parent 8fe406fd17
commit d5c00047da
@@ -612,116 +612,118 @@ class LlamaModel(LlamaPreTrainedModel):
         hidden_states = inputs_embeds
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer
+        return (None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
             )
             use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
 
         # embed positions
         hidden_states = inputs_embeds
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
                     cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
                 )
 
             hidden_states = layer_outputs[0]
 
             if use_cache:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
         next_cache = next_decoder_cache if use_cache else None
         if return_legacy_cache:
             next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
-        )None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
+        )
+        )
 
     def _update_causal_mask(
         self,
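For context on the `normalizer` lines kept above: this generated model scales the token embeddings by sqrt(hidden_size) before the decoder stack. A minimal sketch of the same scaling in isolation, with an illustrative hidden size (the value is an assumption, not from this commit):

    import torch

    hidden_size = 2048  # illustrative; the model reads self.config.hidden_size
    inputs_embeds = torch.randn(1, 4, hidden_size)

    # same arithmetic as the generated forward: scale by sqrt(hidden_size),
    # building the scale factor in the embeddings' dtype
    normalizer = torch.tensor(hidden_size**0.5, dtype=inputs_embeds.dtype)
    hidden_states = inputs_embeds * normalizer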
File diff suppressed because it is too large
@@ -127,7 +127,39 @@ class Starcoder2MLP(nn.Module):
 class Starcoder2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
     def __init__(self, config: LlamaConfig, layer_idx: int | None = None):
-        super().__init__(config, layer_idx) # here call to super means
+        config, layer_idx
+        def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+            super().__init__()
+            self.config = config
+            self.layer_idx = layer_idx
+            if layer_idx is None:
+                logger.warning_once(
+                    f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                    "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                    "when creating this class."
+                )
+
+            self.attention_dropout = config.attention_dropout
+            self.hidden_size = config.hidden_size
+            self.num_heads = config.num_attention_heads
+            self.head_dim = self.hidden_size // self.num_heads
+            self.num_key_value_heads = config.num_key_value_heads
+            self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+            self.max_position_embeddings = config.max_position_embeddings
+            self.rope_theta = config.rope_theta
+            self.is_causal = True
+
+            if (self.head_dim * self.num_heads) != self.hidden_size:
+                raise ValueError(
+                    f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                    f" and `num_heads`: {self.num_heads})."
+                )
+
+            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+            self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+            self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+            self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+            self._init_rope() # here call to super means
         self.attention_dropout = config.attention_dropout
 
     def _init_rope(self):
         if self.config.rope_scaling is None:
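As a quick sanity check on the projection sizes the inlined parent `__init__` sets up: with grouped key/value heads, `k_proj` and `v_proj` project to fewer features than `q_proj`. A sketch with illustrative config values (assumptions, not taken from this commit):

    hidden_size = 4096
    num_attention_heads = 32
    num_key_value_heads = 8  # grouped-query attention: several query heads share one kv head

    head_dim = hidden_size // num_attention_heads                      # 128
    num_key_value_groups = num_attention_heads // num_key_value_heads  # 4

    # out_features of each nn.Linear in the inlined body
    q_out = num_attention_heads * head_dim    # 4096 == hidden_size
    kv_out = num_key_value_heads * head_dim   # 1024, a quarter of q_out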
@@ -673,7 +705,21 @@ class LlamaModel(LlamaPreTrainedModel):
     config: LlamaConfig
     """
     def __init__(self, config):
-        super().__init__(config)
+        config
+        def __init__(self, config: LlamaConfig):
+            super().__init__(config)
+            self.padding_idx = config.pad_token_id
+            self.vocab_size = config.vocab_size
+
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+            self.layers = nn.ModuleList(
+                [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            )
+            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.gradient_checkpointing = False
+
+            # Initialize weights and apply final processing
+            self.post_init()
         self.embedding_dropout = config.embedding_dropout
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -684,116 +730,118 @@ class LlamaModel(LlamaPreTrainedModel):
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)
+        return (None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
             )
             use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
 
         # embed positions
         hidden_states = inputs_embeds
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
                     cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
                 )
 
             hidden_states = layer_outputs[0]
 
             if use_cache:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
         next_cache = next_decoder_cache if use_cache else None
         if return_legacy_cache:
             next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
-        )None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
+        )
+        )
 
     def _update_causal_mask(
         self,
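One behavioral detail in the Starcoder2 context lines above: `nn.functional.dropout(..., training=self.training)` only perturbs activations in train mode and is an identity in eval mode. A minimal check (the dropout probability is an illustrative stand-in for `config.embedding_dropout`):

    import torch
    from torch import nn

    x = torch.ones(2, 3)
    train_out = nn.functional.dropout(x, p=0.1, training=True)   # zeroes ~10% of entries, rescales the rest by 1/0.9
    eval_out = nn.functional.dropout(x, p=0.1, training=False)   # identity
    assert torch.equal(eval_out, x)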
@@ -17,7 +17,19 @@ import inspect
 # with open(CohereConverter.original_file, 'r') as file, open("result.py", "w+") as modeling:
 #     pass
 # TODO also copy and import from all modules in CohereConverter.modules_to_import to be able to use inspect
-import sys
+def replace_super_calls_in_method(method_body, parent_method_body, method_name):
+    # Indent parent method body to match the child's method indentation
+    indent = re.match(r'(\s*)def ', method_body).group(1)
+    indented_parent_method_body = "\n".join([indent + line if line.strip() else line for line in parent_method_body.split('\n')])
+    method_name = method_name.strip()
+    # Handle super().method_name(args) and return super().method_name(args)
+    super_call_pattern = re.compile(r'(\s*)return super\(\)\.' + method_name + r'\((.*?)\)')
+    method_body = super_call_pattern.sub(r'\1return (\2\n' + indented_parent_method_body + r'\1)', method_body)
+
+    super_call_pattern_no_return = re.compile(r'(\s*)super\(\)\.' + method_name + r'\((.*?)\)')
+    method_body = super_call_pattern_no_return.sub(r'\1\2\n' + indented_parent_method_body, method_body)
+
+    return method_body
 
 # 2. Write all the classes. Use the `CohereConverter` class for this.
 def create_single_model_file(converter):
     model_identifier = converter.diff_file.split("diff_")
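A minimal usage sketch of the new helper, assuming a toy child/parent pair (illustrative, not part of the commit). It also explains the regenerated output above: the super call's argument list is kept after `return (` and the parent body, `def` line included, is pasted in underneath, which is exactly the `return (None, attention_mask, ...` pattern visible in the regenerated modeling files.

    child = (
        "    def forward(self, x):\n"
        "        return super().forward(x)\n"
    )
    parent = (
        "def forward(self, x):\n"
        "    return x + 1\n"
    )

    print(replace_super_calls_in_method(child, parent, "forward"))
    #     def forward(self, x):
    #         return (x
    #     def forward(self, x):
    #         return x + 1
    #
    #         )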
@@ -71,7 +83,15 @@ def create_single_model_file(converter):
             # TODO handle call to super!
             full_function = match.group()
             if "def" in full_function:
-                child_function_set[full_function.split("def")[1].split("(")[0]] = re.sub(r" return super\(\).forward\(", parent_function_set.get(function_name,""), full_function)
+                function_name = full_function.split("def")[1].split("(")[0]
+                if f"super()." in full_function or f"return super()." in full_function:
+                    replaced_function = replace_super_calls_in_method(full_function,
+                                                                      parent_function_set.get(function_name,
+                                                                                              ""),
+                                                                      function_name)
+                    child_function_set[function_name] = replaced_function
+                else:
+                    child_function_set[function_name] = full_function
             else:
                 child_function_set[full_function] = full_function
 
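One detail worth noting in the dispatch above: `full_function.split("def")[1].split("(")[0]` returns the method name with its leading space, which is why `replace_super_calls_in_method` strips `method_name` before splicing it into a regex. A quick illustration:

    full_function = "    def forward(self, x):\n        return x\n"
    function_name = full_function.split("def")[1].split("(")[0]
    print(repr(function_name))          # ' forward' -- note the leading space
    print(repr(function_name.strip()))  # 'forward'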