# Source: transformers/examples/diff-conversion/diff_dummy.py
# (Repository viewer metadata: 45 lines, 1.4 KiB, Python.)

from math import log
from typing import List, Optional, Tuple, Union
import torch
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
def _pre_process_input(input_ids):
print(log(input_ids))
return input_ids
# Example where we need some deps and some functions.
class DummyModel(LlamaModel):
    """``LlamaModel`` subclass whose ``forward`` runs ``_pre_process_input`` first."""

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Pre-process ``input_ids`` and delegate to the parent ``forward``.

        ``input_ids`` is first passed through ``_pre_process_input``. Every
        argument is then forwarded positionally, in the order declared here,
        to ``LlamaModel.forward``; see that method for what each one means.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        input_ids = _pre_process_input(input_ids)
        # Forward the pre-processed ids. Passing a literal ``None`` here would
        # discard them and leave the assignment above without any effect.
        return super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )