45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
from math import log
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from transformers import Cache
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
from transformers.models.llama.modeling_llama import LlamaModel
|
|
|
|
|
|
def _pre_process_input(input_ids):
|
|
print(log(input_ids))
|
|
return input_ids
|
|
|
|
|
|
# Example of a model that needs both an imported dependency (math.log) and a module-level helper function (_pre_process_input).
|
|
class DummyModel(LlamaModel):
    """``LlamaModel`` subclass that runs ``_pre_process_input`` on the ids first."""

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Pre-process ``input_ids`` and delegate to the parent ``forward``.

        ``input_ids`` is passed through ``_pre_process_input``; every other
        argument is handed to ``super().forward`` untouched, positionally and
        in the same order as this signature.

        Returns:
            Whatever the parent ``forward`` returns.
            NOTE(review): the annotation names ``CausalLMOutputWithPast``;
            confirm it matches the parent's actual return type.
        """
        input_ids = _pre_process_input(input_ids)

        # Pass the pre-processed ids on to the parent. A literal ``None`` in
        # this slot would throw them away and leave the parent without ids.
        return super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )
|