Compare commits

...

1 Commit

Author         SHA1        Message      Date
Arthur Zucker  b66fdb0a1a  dumb commit  2024-05-31 13:24:43 +02:00
1 changed file with 18 additions and 10 deletions


@@ -855,6 +855,23 @@ LLAMA_INPUTS_DOCSTRING = r"""
             the complete sequence length.
 """
+from typing import TypedDict, Unpack
+
+class ForwardParameters(TypedDict, total=False):
+    input_ids: torch.LongTensor
+    attention_mask: Optional[torch.Tensor]
+    position_ids: Optional[torch.LongTensor]
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]]
+    inputs_embeds: Optional[torch.FloatTensor]
+    use_cache: Optional[bool]
+    output_attentions: Optional[bool]
+    output_hidden_states: Optional[bool]
+    return_dict: Optional[bool]
+    cache_position: Optional[torch.LongTensor]
+
+
+class TGIForwardParameters(ForwardParameters):
+    cu_seq_lens: torch.Tensor
 @add_start_docstrings(
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
@@ -892,16 +909,7 @@ class LlamaModel(LlamaPreTrainedModel):
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TGIForwardParameters],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
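For reference, a minimal standalone sketch of the pattern this commit relies on: typing **kwargs with a TypedDict via Unpack (PEP 692). The GenerateKwargs class and generate function below are hypothetical illustrations, not part of the transformers API; Unpack is available from typing_extensions (and from typing on recent Python versions).

from typing import Optional

from typing_extensions import TypedDict, Unpack


class GenerateKwargs(TypedDict, total=False):
    # total=False makes every key optional, mirroring keyword
    # arguments that carry defaults in a plain signature.
    max_new_tokens: int
    temperature: Optional[float]


def generate(prompt: str, **kwargs: Unpack[GenerateKwargs]) -> str:
    # At runtime this is an ordinary **kwargs dict; static checkers
    # (mypy, pyright) validate the key names and value types against
    # GenerateKwargs at every call site.
    max_new_tokens = kwargs.get("max_new_tokens", 16)
    temperature = kwargs.get("temperature")
    return f"{prompt} [max_new_tokens={max_new_tokens}, temperature={temperature}]"


generate("Hello", max_new_tokens=8)    # accepted
# generate("Hello", max_tokens=8)      # rejected by a type checker: unknown key

Typed kwargs keep the signature compact while preserving per-argument type checking, which appears to be the motivation for folding the ten explicit parameters of LlamaModel.forward into **kwargs: Unpack[TGIForwardParameters] above.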