39 lines
1.3 KiB
Python
39 lines
1.3 KiB
Python
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from transformers import Cache
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
from transformers.models.llama.modeling_llama import LlamaModel
|
|
|
|
|
|
# example where we need some deps and some functions
|
|
class SuperModel(LlamaModel):
    """``LlamaModel`` whose final hidden states are multiplied by ``2**4``.

    ``LlamaModel`` is the headless Llama decoder: its ``forward`` returns the
    final hidden states (``last_hidden_state`` on the returned output object,
    or the first element when a plain tuple is returned) and produces no
    ``logits``. The scaling is therefore applied to those hidden states.
    """

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run ``LlamaModel.forward`` and scale its final hidden states by ``2**4``.

        All arguments are forwarded unchanged to ``LlamaModel.forward``.

        Returns:
            The parent's output with the final hidden states multiplied by
            ``2**4``. When the parent returns an output object, that same
            object is returned with ``last_hidden_state`` replaced. When it
            returns a tuple, a new tuple is returned whose first element is
            scaled and whose remaining elements are unchanged.
        """
        # NOTE(review): the return annotation is kept for compatibility, but the
        # parent is a headless model and returns ``BaseModelOutputWithPast``
        # (or a tuple), not ``CausalLMOutputWithPast``.
        out = super().forward(
            # Pass by keyword: positional binding silently misroutes arguments
            # if the parent's parameter order differs between library versions.
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        if isinstance(out, tuple):
            # Tuple output (e.g. return_dict=False): the final hidden states come
            # first. Tuples are immutable, so build a new one.
            return (out[0] * 2**4,) + out[1:]
        # Out-of-place, so a tensor that may be shared with other outputs
        # (e.g. the last entry of ``hidden_states``) is not mutated.
        out.last_hidden_state = out.last_hidden_state * 2**4
        return out
|